pax_global_header00006660000000000000000000000064132032051020014476gustar00rootroot0000000000000052 comment=17ed2331f389253b9dc7f3fddcab66505bfb7115 asyncssh-1.11.1/000077500000000000000000000000001320320510200134125ustar00rootroot00000000000000asyncssh-1.11.1/.coveragerc000066400000000000000000000001631320320510200155330ustar00rootroot00000000000000[run] branch = True [report] exclude_lines = pragma: no cover raise NotImplementedError omit = .tox/* asyncssh-1.11.1/.gitignore000066400000000000000000000001421320320510200153770ustar00rootroot00000000000000.*.swp MANIFEST __pycache__/ *.py[cod] asyncssh.egg-info build/ dist/ docs/Makefile docs/_build/ asyncssh-1.11.1/.travis.yml000066400000000000000000000044151320320510200155270ustar00rootroot00000000000000language: python install: - pip install tox matrix: allow_failures: - python: "nightly" fast_finish: true include: - os: linux sudo: required dist: trusty python: 3.4 before_install: - .travis/build-libsodium.sh env: - TOXENV=py34 - os: linux sudo: required dist: trusty python: 3.5 before_install: - .travis/build-libsodium.sh env: - TOXENV=py35 - os: linux sudo: required dist: trusty python: 3.6 before_install: - .travis/build-libsodium.sh env: - TOXENV=py36 - os: linux sudo: required dist: trusty python: "nightly" # currently points to 3.7-rc before_install: - .travis/build-libsodium.sh env: - TOXENV=py37 - os: osx osx_image: xcode9.1 language: generic env: - CPPFLAGS=-I/usr/local/opt/openssl/include - LDFLAGS=-L/usr/local/opt/openssl/lib -L/usr/local/opt/libffi/lib - PATH=$HOME/.pyenv/bin:/usr/local/opt/openssl/bin:$PATH - TOXENV=py34 before_install: - brew update - brew install libffi libsodium - eval "$(pyenv init -)" - pyenv install 3.4.6 - pyenv local 3.4.6 - pyenv rehash - os: osx osx_image: xcode9.1 language: generic env: - CPPFLAGS=-I/usr/local/opt/openssl/include - LDFLAGS=-L/usr/local/opt/openssl/lib -L/usr/local/opt/libffi/lib - PATH=$HOME/.pyenv/bin:/usr/local/opt/openssl/bin:$PATH - TOXENV=py35 before_install: - brew update - brew install libffi libsodium - eval "$(pyenv init -)" - pyenv install 3.5.3 - pyenv local 3.5.3 - pyenv rehash - os: osx osx_image: xcode9.1 language: generic env: - CPPFLAGS=-I/usr/local/opt/openssl/include - LDFLAGS=-L/usr/local/opt/openssl/lib -L/usr/local/opt/libffi/lib - PATH=$HOME/.pyenv/bin:/usr/local/opt/openssl/bin:$PATH - TOXENV=py36 before_install: - brew update - brew install libffi libsodium - eval "$(pyenv init -)" - pyenv install 3.6.1 - pyenv local 3.6.1 - pyenv rehash script: travis_wait 60 tox asyncssh-1.11.1/.travis/000077500000000000000000000000001320320510200150005ustar00rootroot00000000000000asyncssh-1.11.1/.travis/build-libsodium.sh000077500000000000000000000003151320320510200204220ustar00rootroot00000000000000#!/bin/bash # # Build and install libsodum 1.0.10 # git clone git://github.com/jedisct1/libsodium.git cd libsodium git checkout tags/1.0.10 ./autogen.sh ./configure make && sudo make install sudo ldconfig asyncssh-1.11.1/CONTRIBUTING.rst000066400000000000000000000074001320320510200160540ustar00rootroot00000000000000Contributing to AsyncSSH ======================== Input on AsyncSSH is extremely welcome. Below are some recommendations of the best ways to contribute. Asking questions ---------------- If you have a general question about how to use AsyncSSH, you are welcome to post it to the end-user mailing list at `asyncssh-users@googlegroups.com `_. 
If you have a question related to the development of AsyncSSH, you can post it to the development mailing list at `asyncssh-dev@googlegroups.com `_. You are also welcome to use the AsyncSSH `issue tracker `_ to ask questions. Reporting bugs -------------- Please use the `issue tracker `_ to report any bugs you find. Before creating a new issue, please check the currently open issues to see if your problem has already been reported. If you create a new issue, please include the version of AsyncSSH you are using, information about the OS you are running on and the installed version of Python and any other libraries that are involved. Please also include detailed information about how to reproduce the problem, including any traceback information you were able to collect or other relevant output. If you have sample code which exhibits the problem, feel free to include that as well. If possible, please test against the latest version of AsyncSSH. Also, if you are testing code in something other than the master branch, it would be helpful to know if you also see the problem in master. Requesting feature enhancements ------------------------------- The `issue tracker `_ should also be used to post feature enhancement requests. While I can't make any promises about what features will be added in the future, suggestions are always welcome! Contributing code ----------------- Before submitting a pull request, please create an issue on the `issue tracker `_ explaining what functionality you'd like to contribute and how it could be used. Discussing the approach you'd like to take up front will make it far more likely I'll be able to accept your changes, or explain what issues might prevent that before you spend a lot of effort. If you find a typo or other small bug in the code, you're welcome to submit a patch without filing an issue first, but for anything larger than a few lines I strongly recommend coordinating up front. Any code you submit will need to be provided with a compatible license. AsyncSSH code is currently released under the `Eclipse Public License v1.0 `_. Before submitting a pull request, make sure to indicate that you are ok with releasing your code under this license and how you'd like to be listed in the contributors list. Branches -------- There are two long-lived branches in AsyncSSH at the moment: * The master branch is intended to contain the latest stable version of the code. All official versions of AsyncSSH are released from this branch, and each release has a corresponding tag added matching its release number. Bug fixes and simple improvements may be checked directly into this branch, but most new features will be added to the develop branch first. * The develop branch is intended to contain features for developers to test before they are ready to be added to an official release. APIs in the develop branch may be subject to change until they are migrated back to master, and there's no guarantee of backward compatibility in this branch. However, pulling from this branch will provide early access to new functionality and a chance to influence this functionality before it is released. asyncssh-1.11.1/COPYRIGHT000066400000000000000000000006011320320510200147020ustar00rootroot00000000000000Copyright (c) 2013-2017 by Ron Frederick . All rights reserved. 
This program and the accompanying materials are made available under the terms of the Eclipse Public License v1.0 which accompanies this distribution and is available at: http://www.eclipse.org/legal/epl-v10.html Contributors: Ron Frederick - initial implementation, API, and documentation asyncssh-1.11.1/LICENSE000066400000000000000000000263721320320510200144310ustar00rootroot00000000000000Eclipse Public License - v 1.0 THE ACCOMPANYING PROGRAM IS PROVIDED UNDER THE TERMS OF THIS ECLIPSE PUBLIC LICENSE ("AGREEMENT"). ANY USE, REPRODUCTION OR DISTRIBUTION OF THE PROGRAM CONSTITUTES RECIPIENT'S ACCEPTANCE OF THIS AGREEMENT. 1. DEFINITIONS "Contribution" means: a) in the case of the initial Contributor, the initial code and documentation distributed under this Agreement, and b) in the case of each subsequent Contributor: i) changes to the Program, and ii) additions to the Program; where such changes and/or additions to the Program originate from and are distributed by that particular Contributor. A Contribution 'originates' from a Contributor if it was added to the Program by such Contributor itself or anyone acting on such Contributor's behalf. Contributions do not include additions to the Program which: (i) are separate modules of software distributed in conjunction with the Program under their own license agreement, and (ii) are not derivative works of the Program. "Contributor" means any person or entity that distributes the Program. "Licensed Patents" mean patent claims licensable by a Contributor which are necessarily infringed by the use or sale of its Contribution alone or when combined with the Program. "Program" means the Contributions distributed in accordance with this Agreement. "Recipient" means anyone who receives the Program under this Agreement, including all Contributors. 2. GRANT OF RIGHTS a) Subject to the terms of this Agreement, each Contributor hereby grants Recipient a non-exclusive, worldwide, royalty-free copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, distribute and sublicense the Contribution of such Contributor, if any, and such derivative works, in source code and object code form. b) Subject to the terms of this Agreement, each Contributor hereby grants Recipient a non-exclusive, worldwide, royalty-free patent license under Licensed Patents to make, use, sell, offer to sell, import and otherwise transfer the Contribution of such Contributor, if any, in source code and object code form. This patent license shall apply to the combination of the Contribution and the Program if, at the time the Contribution is added by the Contributor, such addition of the Contribution causes such combination to be covered by the Licensed Patents. The patent license shall not apply to any other combinations which include the Contribution. No hardware per se is licensed hereunder. c) Recipient understands that although each Contributor grants the licenses to its Contributions set forth herein, no assurances are provided by any Contributor that the Program does not infringe the patent or other intellectual property rights of any other entity. Each Contributor disclaims any liability to Recipient for claims brought by any other entity based on infringement of intellectual property rights or otherwise. As a condition to exercising the rights and licenses granted hereunder, each Recipient hereby assumes sole responsibility to secure any other intellectual property rights needed, if any. 
For example, if a third party patent license is required to allow Recipient to distribute the Program, it is Recipient's responsibility to acquire that license before distributing the Program. d) Each Contributor represents that to its knowledge it has sufficient copyright rights in its Contribution, if any, to grant the copyright license set forth in this Agreement. 3. REQUIREMENTS A Contributor may choose to distribute the Program in object code form under its own license agreement, provided that: a) it complies with the terms and conditions of this Agreement; and b) its license agreement: i) effectively disclaims on behalf of all Contributors all warranties and conditions, express and implied, including warranties or conditions of title and non-infringement, and implied warranties or conditions of merchantability and fitness for a particular purpose; ii) effectively excludes on behalf of all Contributors all liability for damages, including direct, indirect, special, incidental and consequential damages, such as lost profits; iii) states that any provisions which differ from this Agreement are offered by that Contributor alone and not by any other party; and iv) states that source code for the Program is available from such Contributor, and informs licensees how to obtain it in a reasonable manner on or through a medium customarily used for software exchange. When the Program is made available in source code form: a) it must be made available under this Agreement; and b) a copy of this Agreement must be included with each copy of the Program. Contributors may not remove or alter any copyright notices contained within the Program. Each Contributor must identify itself as the originator of its Contribution, if any, in a manner that reasonably allows subsequent Recipients to identify the originator of the Contribution. 4. COMMERCIAL DISTRIBUTION Commercial distributors of software may accept certain responsibilities with respect to end users, business partners and the like. While this license is intended to facilitate the commercial use of the Program, the Contributor who includes the Program in a commercial product offering should do so in a manner which does not create potential liability for other Contributors. Therefore, if a Contributor includes the Program in a commercial product offering, such Contributor ("Commercial Contributor") hereby agrees to defend and indemnify every other Contributor ("Indemnified Contributor") against any losses, damages and costs (collectively "Losses") arising from claims, lawsuits and other legal actions brought by a third party against the Indemnified Contributor to the extent caused by the acts or omissions of such Commercial Contributor in connection with its distribution of the Program in a commercial product offering. The obligations in this section do not apply to any claims or Losses relating to any actual or alleged intellectual property infringement. In order to qualify, an Indemnified Contributor must: a) promptly notify the Commercial Contributor in writing of such claim, and b) allow the Commercial Contributor to control, and cooperate with the Commercial Contributor in, the defense and any related settlement negotiations. The Indemnified Contributor may participate in any such claim at its own expense. For example, a Contributor might include the Program in a commercial product offering, Product X. That Contributor is then a Commercial Contributor. 
If that Commercial Contributor then makes performance claims, or offers warranties related to Product X, those performance claims and warranties are such Commercial Contributor's responsibility alone. Under this section, the Commercial Contributor would have to defend claims against the other Contributors related to those performance claims and warranties, and if a court requires any other Contributor to pay any damages as a result, the Commercial Contributor must pay those damages. 5. NO WARRANTY EXCEPT AS EXPRESSLY SET FORTH IN THIS AGREEMENT, THE PROGRAM IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OR CONDITIONS OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE. Each Recipient is solely responsible for determining the appropriateness of using and distributing the Program and assumes all risks associated with its exercise of rights under this Agreement , including but not limited to the risks and costs of program errors, compliance with applicable laws, damage to or loss of data, programs or equipment, and unavailability or interruption of operations. 6. DISCLAIMER OF LIABILITY EXCEPT AS EXPRESSLY SET FORTH IN THIS AGREEMENT, NEITHER RECIPIENT NOR ANY CONTRIBUTORS SHALL HAVE ANY LIABILITY FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING WITHOUT LIMITATION LOST PROFITS), HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OR DISTRIBUTION OF THE PROGRAM OR THE EXERCISE OF ANY RIGHTS GRANTED HEREUNDER, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 7. GENERAL If any provision of this Agreement is invalid or unenforceable under applicable law, it shall not affect the validity or enforceability of the remainder of the terms of this Agreement, and without further action by the parties hereto, such provision shall be reformed to the minimum extent necessary to make such provision valid and enforceable. If Recipient institutes patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Program itself (excluding combinations of the Program with other software or hardware) infringes such Recipient's patent(s), then such Recipient's rights granted under Section 2(b) shall terminate as of the date such litigation is filed. All Recipient's rights under this Agreement shall terminate if it fails to comply with any of the material terms or conditions of this Agreement and does not cure such failure in a reasonable period of time after becoming aware of such noncompliance. If all Recipient's rights under this Agreement terminate, Recipient agrees to cease use and distribution of the Program as soon as reasonably practicable. However, Recipient's obligations under this Agreement and any licenses granted by Recipient relating to the Program shall continue and survive. Everyone is permitted to copy and distribute copies of this Agreement, but in order to avoid inconsistency the Agreement is copyrighted and may only be modified in the following manner. The Agreement Steward reserves the right to publish new versions (including revisions) of this Agreement from time to time. No one other than the Agreement Steward has the right to modify this Agreement. The Eclipse Foundation is the initial Agreement Steward. 
The Eclipse Foundation may assign the responsibility to serve as the Agreement Steward to a suitable separate entity. Each new version of the Agreement will be given a distinguishing version number. The Program (including Contributions) may always be distributed subject to the version of the Agreement under which it was received. In addition, after a new version of the Agreement is published, Contributor may elect to distribute the Program (including its Contributions) under the new version. Except as expressly stated in Sections 2(a) and 2(b) above, Recipient receives no rights or licenses to the intellectual property of any Contributor under this Agreement, whether expressly, by implication, estoppel or otherwise. All rights in the Program not expressly granted under this Agreement are reserved. This Agreement is governed by the laws of the State of New York and the intellectual property laws of the United States of America. No party to this Agreement will bring a legal action under this Agreement more than one year after the cause of action arose. Each party waives its rights to a jury trial in any resulting litigation. asyncssh-1.11.1/MANIFEST.in000066400000000000000000000001701320320510200151460ustar00rootroot00000000000000include CONTRIBUTING.rst COPYRIGHT LICENSE README.rst pylintrc tox.ini include examples/*.py tests/*.py tests_py35/*.py asyncssh-1.11.1/README.rst000066400000000000000000000141301320320510200151000ustar00rootroot00000000000000AsyncSSH: Asynchronous SSH for Python ===================================== AsyncSSH is a Python package which provides an asynchronous client and server implementation of the SSHv2 protocol on top of the Python 3.4+ asyncio framework. .. code:: python import asyncio, asyncssh, sys async def run_client(): async with asyncssh.connect('localhost') as conn: result = await conn.run('echo "Hello!"', check=True) print(result.stdout, end='') try: asyncio.get_event_loop().run_until_complete(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) Check out the `examples`__ to get started! 
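Since AsyncSSH supports multiple simultaneous sessions on a single
connection, the same pattern extends naturally to running several
commands concurrently. The sketch below uses only the calls shown in
the example above; the command names are placeholders:

.. code:: python

    import asyncio, asyncssh, sys

    async def run_commands():
        async with asyncssh.connect('localhost') as conn:
            # Each conn.run() opens its own session on the shared connection
            results = await asyncio.gather(conn.run('ls'),
                                           conn.run('uptime'))

            for result in results:
                print(result.stdout, end='')

    try:
        asyncio.get_event_loop().run_until_complete(run_commands())
    except (OSError, asyncssh.Error) as exc:
        sys.exit('SSH connection failed: ' + str(exc))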
__ http://asyncssh.readthedocs.io/en/stable/#client-examples Features -------- * Full support for SSHv2, SFTP, and SCP client and server functions * Shell, command, and subsystem channels * Environment variables, terminal type, and window size * Direct and forwarded TCP/IP channels * OpenSSH-compatible direct and forwarded UNIX domain socket channels * Local and remote TCP/IP port forwarding * Local and remote UNIX domain socket forwarding * X11 forwarding support on both the client and the server * SFTP protocol version 3 with OpenSSH extensions * SCP protocol support, including third-party remote to remote copies * Multiple simultaneous sessions on a single SSH connection * Multiple SSH connections in a single event loop * Byte and string based I/O with settable encoding * A variety of `key exchange`__, `encryption`__, and `MAC`__ algorithms * Support for `gzip compression`__ * Including OpenSSH variant to delay compression until after auth * Password, public key, and keyboard-interactive user authentication methods * Many types and formats of `public keys and certificates`__ * Including support for X.509 certificates as defined in RFC 6187 * Support for accessing keys managed by `ssh-agent`__ on UNIX systems * Including agent forwarding support on both the client and the server * Support for accessing keys managed by PuTTY's Pageant agent on Windows * OpenSSH-style `known_hosts file`__ support * OpenSSH-style `authorized_keys file`__ support * Compatibility with OpenSSH "Encrypt then MAC" option for better security * Time and byte-count based session key renegotiation * Designed to be easy to extend to support new forms of key exchange, authentication, encryption, and compression algorithms __ http://asyncssh.readthedocs.io/en/stable/api.html#key-exchange-algorithms __ http://asyncssh.readthedocs.io/en/stable/api.html#encryption-algorithms __ http://asyncssh.readthedocs.io/en/stable/api.html#mac-algorithms __ http://asyncssh.readthedocs.io/en/stable/api.html#compression-algorithms __ http://asyncssh.readthedocs.io/en/stable/api.html#public-key-support __ http://asyncssh.readthedocs.io/en/stable/api.html#ssh-agent-support __ http://asyncssh.readthedocs.io/en/stable/api.html#known-hosts __ http://asyncssh.readthedocs.io/en/stable/api.html#authorized-keys License ------- This package is released under the following terms: Copyright (c) 2013-2017 by Ron Frederick . All rights reserved. This program and the accompanying materials are made available under the terms of the **Eclipse Public License v1.0** which accompanies this distribution and is available at: http://www.eclipse.org/legal/epl-v10.html For more information about this license, please see the `Eclipse Public License FAQ `_. Prerequisites ------------- To use ``asyncssh``, you need the following: * Python 3.4 or later * cryptography (PyCA) 1.1 or later Installation ------------ Install AsyncSSH by running: :: pip install asyncssh Optional Extras ^^^^^^^^^^^^^^^ There are some optional modules you can install to enable additional functionality: * Install bcrypt from https://pypi.python.org/pypi/bcrypt if you want support for OpenSSH private key encryption. * Install gssapi from https://pypi.python.org/pypi/gssapi if you want support for GSSAPI key exchange and authentication on UNIX. * Install libsodium from https://github.com/jedisct1/libsodium and libnacl from https://pypi.python.org/pypi/libnacl if you want support for curve25519 Diffie Hellman key exchange, ed25519 keys, and the chacha20-poly1305 cipher. 
* Install libnettle from http://www.lysator.liu.se/~nisse/nettle/ if you want support for UMAC cryptographic hashes. * Install pyOpenSSL from https://pypi.python.org/pypi/pyOpenSSL if you want support for X.509 certificate authentication. * Install pypiwin32 from https://pypi.python.org/pypi/pypiwin32 if you want support for using the Pageant agent or support for GSSAPI key exchange and authentication on Windows. AsyncSSH defines the following optional PyPI extra packages to make it easy to install any or all of these dependencies: | bcrypt | gssapi | libnacl | pyOpenSSL | pypiwin32 For example, to install bcrypt, gssapi, libnacl, and pyOpenSSL on UNIX, you can run: :: pip install 'asyncssh[bcrypt,gssapi,libnacl,pyOpenSSL]' To install bcrypt, libnacl, pyOpenSSL, and pypiwin32 on Windows, you can run: :: pip install 'asyncssh[bcrypt,libnacl,pyOpenSSL,pypiwin32]' Note that you will still need to manually install the libsodium library listed above for libnacl to work correctly and/or libnettle for UMAC support. Unfortunately, since libsodium and libnettle are not Python packages, they cannot be directly installed using pip. Installing the development branch ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ If you would like to install the development branch of asyncssh directly from Github, you can use the following command to do this: :: pip install git+https://github.com/ronf/asyncssh@develop Mailing Lists ------------- Three mailing lists are available for AsyncSSH: * `asyncssh-announce@googlegroups.com`__: Project announcements * `asyncssh-dev@googlegroups.com`__: Development discussions * `asyncssh-users@googlegroups.com`__: End-user discussions __ http://groups.google.com/d/forum/asyncssh-announce __ http://groups.google.com/d/forum/asyncssh-dev __ http://groups.google.com/d/forum/asyncssh-users asyncssh-1.11.1/appveyor.yml000066400000000000000000000013111320320510200157760ustar00rootroot00000000000000environment: matrix: - PYTHON: "C:\\Python34" ARCH: win32 - PYTHON: "C:\\Python34-x64" ARCH: win64 - PYTHON: "C:\\Python35" ARCH: win32 - PYTHON: "C:\\Python35-x64" ARCH: win64 - PYTHON: "C:\\Python36" ARCH: win32 - PYTHON: "C:\\Python36-x64" ARCH: win64 install: - "cd %PYTHON%" - "curl https://www.timeheart.net/appveyor/%ARCH%/libsodium-18.dll -O" - "curl https://www.timeheart.net/appveyor/%ARCH%/libnettle-6.dll -O" - "curl https://www.timeheart.net/appveyor/%ARCH%/libhogweed-4.dll -O" - "cd %APPVEYOR_BUILD_FOLDER%" - "%PYTHON%\\python.exe -m pip install tox" build: off test_script: - "%PYTHON%\\python.exe -m tox -e py" asyncssh-1.11.1/asyncssh/000077500000000000000000000000001320320510200152455ustar00rootroot00000000000000asyncssh-1.11.1/asyncssh/__init__.py000066400000000000000000000052101320320510200173540ustar00rootroot00000000000000# Copyright (c) 2013-2017 by Ron Frederick . # All rights reserved. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """An asynchronous SSH2 library for Python""" from .version import __author__, __author_email__, __url__, __version__ # pylint: disable=wildcard-import from .constants import * # pylint: enable=wildcard-import from .agent import SSHAgentClient, SSHAgentKeyPair, connect_agent from .auth_keys import SSHAuthorizedKeys from .auth_keys import import_authorized_keys, read_authorized_keys from .channel import SSHClientChannel, SSHServerChannel from .channel import SSHTCPChannel, SSHUNIXChannel from .client import SSHClient from .connection import SSHClientConnection, SSHServerConnection from .connection import create_connection, create_server, connect, listen from .editor import SSHLineEditorChannel from .known_hosts import SSHKnownHosts from .known_hosts import import_known_hosts, read_known_hosts from .known_hosts import match_known_hosts from .listener import SSHListener from .logging import logger from .misc import Error, DisconnectError, ChannelOpenError from .misc import PasswordChangeRequired from .misc import BreakReceived, SignalReceived, TerminalSizeChanged from .pbe import KeyEncryptionError from .process import SSHClientProcess, SSHServerProcess from .process import SSHCompletedProcess, ProcessError from .process import DEVNULL, PIPE, STDOUT from .public_key import SSHKey, SSHKeyPair, SSHCertificate from .public_key import KeyGenerationError, KeyImportError, KeyExportError from .public_key import generate_private_key, import_private_key from .public_key import import_public_key, import_certificate from .public_key import read_private_key, read_public_key, read_certificate from .public_key import read_private_key_list, read_public_key_list from .public_key import read_certificate_list from .public_key import load_keypairs, load_public_keys, load_certificates from .scp import scp from .session import SSHClientSession, SSHServerSession from .session import SSHTCPSession, SSHUNIXSession from .server import SSHServer from .sftp import SFTPClient, SFTPClientFile, SFTPServer, SFTPError from .sftp import SFTPAttrs, SFTPVFSAttrs, SFTPName from .sftp import SEEK_SET, SEEK_CUR, SEEK_END from .stream import SSHReader, SSHWriter # Import these explicitly to trigger register calls in them from . import ed25519, ecdsa, rsa, dsa, ecdh, dh asyncssh-1.11.1/asyncssh/agent.py000066400000000000000000000451141320320510200167220ustar00rootroot00000000000000# Copyright (c) 2016 by Ron Frederick . # All rights reserved. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH agent client""" import asyncio import errno import os import sys import tempfile import asyncssh from .logging import logger try: if sys.platform == 'win32': # pragma: no cover from .agent_win32 import open_agent else: from .agent_unix import open_agent except ImportError as exc: # pragma: no cover def open_agent(loop, agent_path, reason=str(exc)): """Dummy function if we're unable to import agent support""" # pylint: disable=unused-argument raise OSError(errno.ENOENT, 'Agent support unavailable: %s' % reason) from .listener import create_unix_forward_listener from .misc import ChannelOpenError, load_default_keypairs from .packet import Byte, String, UInt32, PacketDecodeError, SSHPacket from .public_key import SSHKeyPair # pylint: disable=bad-whitespace # Client request message numbers SSH_AGENTC_REQUEST_IDENTITIES = 11 SSH_AGENTC_SIGN_REQUEST = 13 SSH_AGENTC_ADD_IDENTITY = 17 SSH_AGENTC_REMOVE_IDENTITY = 18 SSH_AGENTC_REMOVE_ALL_IDENTITIES = 19 SSH_AGENTC_ADD_SMARTCARD_KEY = 20 SSH_AGENTC_REMOVE_SMARTCARD_KEY = 21 SSH_AGENTC_LOCK = 22 SSH_AGENTC_UNLOCK = 23 SSH_AGENTC_ADD_ID_CONSTRAINED = 25 SSH_AGENTC_ADD_SMARTCARD_KEY_CONSTRAINED = 26 SSH_AGENTC_EXTENSION = 27 # Agent response message numbers SSH_AGENT_FAILURE = 5 SSH_AGENT_SUCCESS = 6 SSH_AGENT_IDENTITIES_ANSWER = 12 SSH_AGENT_SIGN_RESPONSE = 14 SSH_AGENT_EXTENSION_FAILURE = 28 # SSH agent constraint numbers SSH_AGENT_CONSTRAIN_LIFETIME = 1 SSH_AGENT_CONSTRAIN_CONFIRM = 2 SSH_AGENT_CONSTRAIN_EXTENSION = 3 # SSH agent signature flags SSH_AGENT_RSA_SHA2_256 = 2 SSH_AGENT_RSA_SHA2_512 = 4 # pylint: enable=bad-whitespace class _X11AgentListener: """Listener used to forward agent connections""" def __init__(self, tempdir, path, unix_listener): self._tempdir = tempdir self._path = path self._unix_listener = unix_listener def get_path(self): """Return the path being listened on""" return self._path def close(self): """Close the agent listener""" self._unix_listener.close() self._tempdir.cleanup() class SSHAgentKeyPair(SSHKeyPair): """Surrogate for a key managed by the SSH agent""" _key_type = 'agent' def __init__(self, agent, algorithm, public_data, comment): super().__init__(algorithm, comment) self._agent = agent self.public_data = public_data self._cert = algorithm.endswith(b'-cert-v01@openssh.com') self._flags = 0 if self._cert: self.sig_algorithm = algorithm[:-21] else: self.sig_algorithm = algorithm if self.sig_algorithm == b'ssh-rsa': self.sig_algorithms = (b'rsa-sha2-256', b'rsa-sha2-512', b'ssh-rsa') else: self.sig_algorithms = (self.sig_algorithm,) if self._cert: self.host_key_algorithms = (algorithm,) else: self.host_key_algorithms = self.sig_algorithms def set_sig_algorithm(self, sig_algorithm): """Set the signature algorithm to use when signing data""" self.sig_algorithm = sig_algorithm if not self._cert: self.algorithm = sig_algorithm if sig_algorithm == b'rsa-sha2-256': self._flags |= SSH_AGENT_RSA_SHA2_256 elif sig_algorithm == b'rsa-sha2-512': self._flags |= SSH_AGENT_RSA_SHA2_512 @asyncio.coroutine def sign(self, data): """Sign a block of data with this private key""" return (yield from self._agent.sign(self.public_data, data, self._flags)) @asyncio.coroutine def remove(self): """Remove this key pair from the 
agent""" yield from self._agent.remove_keys([self]) class SSHAgentClient: """SSH agent client""" def __init__(self, loop, agent_path): self._loop = loop self._agent_path = agent_path self._reader = None self._writer = None self._lock = asyncio.Lock(loop=loop) def _cleanup(self): """Clean up this SSH agent client""" if self._writer: self._writer.close() self._reader = None self._writer = None @staticmethod def encode_constraints(lifetime, confirm): """Encode key constraints""" result = b'' if lifetime: result += Byte(SSH_AGENT_CONSTRAIN_LIFETIME) + UInt32(lifetime) if confirm: result += Byte(SSH_AGENT_CONSTRAIN_CONFIRM) return result @asyncio.coroutine def connect(self): """Connect to the SSH agent""" if isinstance(self._agent_path, asyncssh.SSHServerConnection): self._reader, self._writer = \ yield from self._agent_path.open_agent_connection() else: self._reader, self._writer = \ yield from open_agent(self._loop, self._agent_path) @asyncio.coroutine def _make_request(self, msgtype, *args): """Send an SSH agent request""" with (yield from self._lock): try: if not self._writer: yield from self.connect() payload = Byte(msgtype) + b''.join(args) self._writer.write(UInt32(len(payload)) + payload) resplen = yield from self._reader.readexactly(4) resplen = int.from_bytes(resplen, 'big') resp = yield from self._reader.readexactly(resplen) resp = SSHPacket(resp) resptype = resp.get_byte() return resptype, resp except (OSError, EOFError, PacketDecodeError) as exc: self._cleanup() raise ValueError(str(exc)) from None @asyncio.coroutine def get_keys(self): """Request the available client keys This method is a coroutine which returns a list of client keys available in the ssh-agent. :returns: A list of :class:`SSHKeyPair` objects """ resptype, resp = \ yield from self._make_request(SSH_AGENTC_REQUEST_IDENTITIES) if resptype == SSH_AGENT_IDENTITIES_ANSWER: result = [] num_keys = resp.get_uint32() for _ in range(num_keys): key_blob = resp.get_string() comment = resp.get_string() packet = SSHPacket(key_blob) algorithm = packet.get_string() result.append(SSHAgentKeyPair(self, algorithm, key_blob, comment)) resp.check_end() return result else: raise ValueError('Unknown SSH agent response: %d' % resptype) @asyncio.coroutine def sign(self, key_blob, data, flags=0): """Sign a block of data with the requested key""" resptype, resp = \ yield from self._make_request(SSH_AGENTC_SIGN_REQUEST, String(key_blob), String(data), UInt32(flags)) if resptype == SSH_AGENT_SIGN_RESPONSE: sig = resp.get_string() resp.check_end() return sig elif resptype == SSH_AGENT_FAILURE: raise ValueError('Unable to sign with requested key') else: raise ValueError('Unknown SSH agent response: %d' % resptype) @asyncio.coroutine def add_keys(self, keylist=(), passphrase=None, lifetime=None, confirm=False): """Add keys to the agent This method adds a list of local private keys and optional matching certificates to the agent. :param keylist: (optional) The list of keys to add. If not specified, an attempt will be made to load keys from the files :file:`.ssh/id_ed25519`, :file:`.ssh/id_ecdsa`, :file:`.ssh/id_rsa` and :file:`.ssh/id_dsa` in the user's home directory with optional matching certificates loaded from the files :file:`.ssh/id_ed25519-cert.pub`, :file:`.ssh/id_ecdsa-cert.pub`, :file:`.ssh/id_rsa-cert.pub`, and :file:`.ssh/id_dsa-cert.pub`. :param str passphrase: (optional) The passphrase to use to decrypt the keys. 
:param lifetime: (optional) The time in seconds after which the keys should be automatically deleted, or ``None`` to store these keys indefinitely (the default). :param bool confirm: (optional) Whether or not to require confirmation for each private key operation which uses these keys, defaulting to ``False``. :type keylist: *see* :ref:`SpecifyingPrivateKeys` :type lifetime: `int` or ``None`` :raises: :exc:`ValueError` if the keys cannot be added """ if keylist: keypairs = asyncssh.load_keypairs(keylist, passphrase) else: keypairs = load_default_keypairs(passphrase) constraints = self.encode_constraints(lifetime, confirm) msgtype = SSH_AGENTC_ADD_ID_CONSTRAINED if constraints else \ SSH_AGENTC_ADD_IDENTITY for keypair in keypairs: comment = keypair.get_comment() resptype, resp = \ yield from self._make_request(msgtype, keypair.get_agent_private_key(), String(comment or ''), constraints) if resptype == SSH_AGENT_SUCCESS: resp.check_end() elif resptype == SSH_AGENT_FAILURE: raise ValueError('Unable to add key') else: raise ValueError('Unknown SSH agent response: %d' % resptype) @asyncio.coroutine def add_smartcard_keys(self, provider, pin=None, lifetime=None, confirm=False): """Store keys associated with a smart card in the agent :param str provider: The name of the smart card provider :param pin: (optional) The PIN to use to unlock the smart card :param lifetime: (optional) The time in seconds after which the keys should be automatically deleted, or ``None`` to store these keys indefinitely (the default). :param bool confirm: (optional) Whether or not to require confirmation for each private key operation which uses these keys, defaulting to ``False``. :type pin: `str` or ``None`` :type lifetime: `int` or ``None`` :raises: :exc:`ValueError` if the keys cannot be added """ constraints = self.encode_constraints(lifetime, confirm) msgtype = SSH_AGENTC_ADD_SMARTCARD_KEY_CONSTRAINED \ if constraints else SSH_AGENTC_ADD_SMARTCARD_KEY resptype, resp = \ yield from self._make_request(msgtype, String(provider), String(pin or ''), constraints) if resptype == SSH_AGENT_SUCCESS: resp.check_end() elif resptype == SSH_AGENT_FAILURE: raise ValueError('Unable to add keys') else: raise ValueError('Unknown SSH agent response: %d' % resptype) @asyncio.coroutine def remove_keys(self, keylist): """Remove a key stored in the agent :param keylist: The list of keys to remove. 
:type keylist: list of :class:`SSHKeyPair` :raises: :exc:`ValueError` if any keys are not found """ for keypair in keylist: resptype, resp = \ yield from self._make_request(SSH_AGENTC_REMOVE_IDENTITY, String(keypair.public_data)) if resptype == SSH_AGENT_SUCCESS: resp.check_end() elif resptype == SSH_AGENT_FAILURE: raise ValueError('Key not found') else: raise ValueError('Unknown SSH agent response: %d' % resptype) @asyncio.coroutine def remove_smartcard_keys(self, provider, pin=None): """Remove keys associated with a smart card stored in the agent :param str provider: The name of the smart card provider :param pin: (optional) The PIN to use to unlock the smart card :type pin: `str` or ``None`` :raises: :exc:`ValueError` if the keys are not found """ resptype, resp = \ yield from self._make_request(SSH_AGENTC_REMOVE_SMARTCARD_KEY, String(provider), String(pin or '')) if resptype == SSH_AGENT_SUCCESS: resp.check_end() elif resptype == SSH_AGENT_FAILURE: raise ValueError('Keys not found') else: raise ValueError('Unknown SSH agent response: %d' % resptype) @asyncio.coroutine def remove_all(self): """Remove all keys stored in the agent :raises: :exc:`ValueError` if the keys can't be removed """ resptype, resp = \ yield from self._make_request(SSH_AGENTC_REMOVE_ALL_IDENTITIES) if resptype == SSH_AGENT_SUCCESS: resp.check_end() elif resptype == SSH_AGENT_FAILURE: raise ValueError('Unable to remove all keys') else: raise ValueError('Unknown SSH agent response: %d' % resptype) @asyncio.coroutine def lock(self, passphrase): """Lock the agent using the specified passphrase :param str passphrase: The passphrase required to later unlock the agent :raises: :exc:`ValueError` if the agent can't be locked """ resptype, resp = yield from self._make_request(SSH_AGENTC_LOCK, String(passphrase)) if resptype == SSH_AGENT_SUCCESS: resp.check_end() elif resptype == SSH_AGENT_FAILURE: raise ValueError('Unable to lock SSH agent') else: raise ValueError('Unknown SSH agent response: %d' % resptype) @asyncio.coroutine def unlock(self, passphrase): """Unlock the agent using the specified passphrase :param str passphrase: The passphrase to use to unlock the agent :raises: :exc:`ValueError` if the agent can't be unlocked """ resptype, resp = yield from self._make_request(SSH_AGENTC_UNLOCK, String(passphrase)) if resptype == SSH_AGENT_SUCCESS: resp.check_end() elif resptype == SSH_AGENT_FAILURE: raise ValueError('Unable to unlock SSH agent') else: raise ValueError('Unknown SSH agent response: %d' % resptype) @asyncio.coroutine def query_extensions(self): """Return a list of extensions supported by the agent :returns: A list of strings of supported extension names """ resptype, resp = yield from self._make_request(SSH_AGENTC_EXTENSION, String('query')) if resptype == SSH_AGENT_SUCCESS: result = [] while resp: exttype = resp.get_string() try: exttype = exttype.decode('utf-8') except UnicodeDecodeError: raise ValueError('Invalid extension type name') result.append(exttype) return result elif resptype == SSH_AGENT_FAILURE: return [] else: raise ValueError('Unknown SSH agent response: %d' % resptype) def close(self): """Close the SSH agent connection This method closes the connection to the ssh-agent. Any attempts to use this :class:`SSHAgentClient` or the key pairs it previously returned will result in an error. 
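        A typical lifecycle sketch (assuming an agent is reachable, for
        example via the ``SSH_AUTH_SOCK`` environment variable)::

            agent = yield from asyncssh.connect_agent()

            if agent:
                keys = yield from agent.get_keys()

                # ... use the returned key pairs, e.g. for authentication ...

                agent.close()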
""" self._cleanup() @asyncio.coroutine def connect_agent(agent_path=None, *, loop=None): """Make a connection to the SSH agent This function attempts to connect to an ssh-agent process listening on a UNIX domain socket at ``agent_path``. If not provided, it will attempt to get the path from the ``SSH_AUTH_SOCK`` environment variable. If the connection is successful, an ``SSHAgentClient`` object is returned that has methods on it you can use to query the ssh-agent. If no path is specified and the environment variable is not set or the connection to the agent fails, this function returns ``None``. :param agent_path: (optional) The path to use to contact the ssh-agent process, or the :class:`SSHServerConnection` to forward the agent request over. :param loop: (optional) The event loop to use when creating the connection. If not specified, the default event loop is used. :type agent_path: str or :class:`SSHServerConnection` :returns: An :class:`SSHAgentClient` or ``None`` """ agent = SSHAgentClient(loop, agent_path) try: yield from agent.connect() return agent except (OSError, ChannelOpenError) as exc: logger.debug('Unable to contact agent: %s', exc) return None @asyncio.coroutine def create_agent_listener(conn, loop): """Create a listener for forwarding ssh-agent connections""" try: tempdir = tempfile.TemporaryDirectory(prefix='asyncssh-') path = os.path.join(tempdir.name, 'agent') unix_listener = yield from create_unix_forward_listener( conn, loop, conn.create_agent_connection, path) return _X11AgentListener(tempdir, path, unix_listener) except OSError: return None asyncssh-1.11.1/asyncssh/agent_unix.py000066400000000000000000000015561320320510200177670ustar00rootroot00000000000000# Copyright (c) 2016 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH agent support code for UNIX""" import asyncio import errno import os @asyncio.coroutine def open_agent(loop, agent_path): """Open a connection to ssh-agent""" if not loop: loop = asyncio.get_event_loop() if not agent_path: agent_path = os.environ.get('SSH_AUTH_SOCK', None) if not agent_path: raise OSError(errno.ENOENT, 'Agent not found') return (yield from asyncio.open_unix_connection(agent_path, loop=loop)) asyncssh-1.11.1/asyncssh/agent_win32.py000066400000000000000000000060561320320510200177460ustar00rootroot00000000000000# Copyright (c) 2016 by Ron Frederick . # All rights reserved. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH agent support code for Windows""" # Some of the imports below won't be found when running pylint on UNIX # pylint: disable=import-error import asyncio import ctypes import ctypes.wintypes import errno import mmapfile import win32api import win32con import win32ui _AGENT_COPYDATA_ID = 0x804e50ba _AGENT_MAX_MSGLEN = 8192 _AGENT_NAME = 'Pageant' def _find_agent_window(): """Find and return the Pageant window""" try: return win32ui.FindWindow(_AGENT_NAME, _AGENT_NAME) except win32ui.error: raise OSError(errno.ENOENT, 'Agent not found') from None class _CopyDataStruct(ctypes.Structure): """Windows COPYDATASTRUCT argument for WM_COPYDATA message""" _fields_ = (('dwData', ctypes.wintypes.LPARAM), ('cbData', ctypes.wintypes.DWORD), ('lpData', ctypes.c_char_p)) class _PageantTransport: """Transport to connect to Pageant agent on Windows""" def __init__(self): self._mapname = '%s%08x' % (_AGENT_NAME, win32api.GetCurrentThreadId()) try: self._mapfile = mmapfile.mmapfile(None, self._mapname, _AGENT_MAX_MSGLEN, 0, 0) except mmapfile.error as exc: raise OSError(errno.EIO, str(exc)) from None self._cds = _CopyDataStruct(_AGENT_COPYDATA_ID, len(self._mapname) + 1, self._mapname.encode()) self._writing = False def write(self, data): """Write request data to Pageant agent""" if not self._writing: self._mapfile.seek(0) self._writing = True try: self._mapfile.write(data) except ValueError as exc: raise OSError(errno.EIO, str(exc)) from None @asyncio.coroutine def readexactly(self, n): """Read response data from Pageant agent""" if self._writing: cwnd = _find_agent_window() if not cwnd.SendMessage(win32con.WM_COPYDATA, None, self._cds): raise OSError(errno.EIO, 'Unable to send agent request') self._writing = False self._mapfile.seek(0) result = self._mapfile.read(n) if len(result) != n: raise asyncio.IncompleteReadError(result, n) return result def close(self): """Close the connection to Pageant""" if self._mapfile: self._mapfile.close() self._mapfile = None @asyncio.coroutine def open_agent(loop, agent_path): """Open a connection to the Pageant agent""" # pylint: disable=unused-argument _find_agent_window() transport = _PageantTransport() return transport, transport asyncssh-1.11.1/asyncssh/asn1.py000066400000000000000000000522641320320510200164720ustar00rootroot00000000000000# Copyright (c) 2013-2015 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Utilities for encoding and decoding ASN.1 DER data The der_encode function takes a Python value and encodes it in DER format, returning a byte string. In addition to supporting standard Python types, BitString can be used to encode a DER bit string, ObjectIdentifier can be used to encode OIDs, values can be wrapped in a TaggedDERObject to set an alternate DER tag on them, and non-standard types can be encoded by placing them in a RawDERObject. The der_decode function takes a byte string in DER format and decodes it into the corresponding Python values. 
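As an illustrative round trip (octets shown in hex), a tuple containing
an integer and a byte string encodes as a DER SEQUENCE and decodes back
to the same values:

    data = der_encode((1, b'ab'))    # octets 30 07 02 01 01 04 02 61 62
    value = der_decode(data)         # returns the tuple (1, b'ab')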
""" # pylint: disable=bad-whitespace # ASN.1 object classes UNIVERSAL = 0x00 APPLICATION = 0x01 CONTEXT_SPECIFIC = 0x02 PRIVATE = 0x03 # ASN.1 universal object tags END_OF_CONTENT = 0x00 BOOLEAN = 0x01 INTEGER = 0x02 BIT_STRING = 0x03 OCTET_STRING = 0x04 NULL = 0x05 OBJECT_IDENTIFIER = 0x06 UTF8_STRING = 0x0c SEQUENCE = 0x10 SET = 0x11 IA5_STRING = 0x16 # pylint: enable=bad-whitespace _asn1_class = ('Universal', 'Application', 'Context-specific', 'Private') _der_class_by_tag = {} _der_class_by_type = {} def _encode_identifier(asn1_class, constructed, tag): """Encode a DER object's identifier""" if asn1_class not in (UNIVERSAL, APPLICATION, CONTEXT_SPECIFIC, PRIVATE): raise ASN1EncodeError('Invalid ASN.1 class') flags = (asn1_class << 6) | (0x20 if constructed else 0x00) if tag < 0x20: identifier = [flags | tag] else: identifier = [tag & 0x7f] while tag >= 0x80: tag >>= 7 identifier.append(0x80 | (tag & 0x7f)) identifier.append(flags | 0x1f) return bytes(identifier[::-1]) class ASN1Error(ValueError): """ASN.1 coding error""" class ASN1EncodeError(ASN1Error): """ASN.1 DER encoding error""" class ASN1DecodeError(ASN1Error): """ASN.1 DER decoding error""" class DERTag: """A decorator used by classes which convert values to/from DER Classes which convert Python values to and from DER format should use the DERTag decorator to indicate what DER tag value they understand. When DER data is decoded, the tag is looked up in the list to see which class to call to perform the decoding. Classes which convert existing Python types to and from DER format can specify the list of types they understand in the optional "types" argument. Otherwise, conversion is expected to be to and from the new class being defined. """ def __init__(self, tag, types=(), constructed=False): self._tag = tag self._types = types self._identifier = _encode_identifier(UNIVERSAL, constructed, tag) def __call__(self, cls): cls.identifier = self._identifier _der_class_by_tag[self._tag] = cls if self._types: for t in self._types: _der_class_by_type[t] = cls else: _der_class_by_type[cls] = cls return cls class RawDERObject: """A class which can encode a DER object of an arbitrary type This object is initialized with an ASN.1 class, tag, and a byte string representing the already encoded data. Such objects will never have the constructed flag set, since that is represented here as a TaggedDERObject. """ def __init__(self, tag, content, asn1_class): self.asn1_class = asn1_class self.tag = tag self.content = content def __repr__(self): return ('RawDERObject(%s, %s, %r)' % (_asn1_class[self.asn1_class], self.tag, self.content)) def __eq__(self, other): return (isinstance(other, type(self)) and self.asn1_class == other.asn1_class and self.tag == other.tag and self.content == other.content) def __hash__(self): return hash((self.asn1_class, self.tag, self.content)) def encode_identifier(self): """Encode the DER identifier for this object as a byte string""" return _encode_identifier(self.asn1_class, False, self.tag) def encode(self): """Encode the content for this object as a DER byte string""" return self.content class TaggedDERObject: """An explicitly tagged DER object This object provides a way to wrap an ASN.1 object with an explicit tag. The value (including the tag representing its actual type) is then encoded as part of its value. By default, the ASN.1 class for these objects is CONTEXT_SPECIFIC, and the DER encoding always marks these values as constructed. 
""" def __init__(self, tag, value, asn1_class=CONTEXT_SPECIFIC): self.asn1_class = asn1_class self.tag = tag self.value = value def __repr__(self): if self.asn1_class == CONTEXT_SPECIFIC: return 'TaggedDERObject(%s, %r)' % (self.tag, self.value) else: return ('TaggedDERObject(%s, %s, %r)' % (_asn1_class[self.asn1_class], self.tag, self.value)) def __eq__(self, other): return (isinstance(other, type(self)) and self.asn1_class == other.asn1_class and self.tag == other.tag and self.value == other.value) def __hash__(self): return hash((self.asn1_class, self.tag, self.value)) def encode_identifier(self): """Encode the DER identifier for this object as a byte string""" return _encode_identifier(self.asn1_class, True, self.tag) def encode(self): """Encode the content for this object as a DER byte string""" return der_encode(self.value) @DERTag(NULL, (type(None),)) class _Null: """A null value""" @staticmethod def encode(value): """Encode a DER null value""" # pylint: disable=unused-argument return b'' @classmethod def decode(cls, constructed, content): """Decode a DER null value""" if constructed: raise ASN1DecodeError('NULL should not be constructed') if content: raise ASN1DecodeError('NULL should not have associated content') return None @DERTag(BOOLEAN, (bool,)) class _Boolean: """A boolean value""" @staticmethod def encode(value): """Encode a DER boolean value""" return b'\xff' if value else b'\0' @classmethod def decode(cls, constructed, content): """Decode a DER boolean value""" if constructed: raise ASN1DecodeError('BOOLEAN should not be constructed') if content not in {b'\x00', b'\xff'}: raise ASN1DecodeError('BOOLEAN content must be 0x00 or 0xff') return bool(content[0]) @DERTag(INTEGER, (int,)) class _Integer: """An integer value""" @staticmethod def encode(value): """Encode a DER integer value""" l = value.bit_length() l = l // 8 + 1 if l % 8 == 0 else (l + 7) // 8 result = value.to_bytes(l, 'big', signed=True) return result[1:] if result.startswith(b'\xff\x80') else result @classmethod def decode(cls, constructed, content): """Decode a DER integer value""" if constructed: raise ASN1DecodeError('INTEGER should not be constructed') return int.from_bytes(content, 'big', signed=True) @DERTag(OCTET_STRING, (bytes, bytearray)) class _OctetString: """An octet string value""" @staticmethod def encode(value): """Encode a DER octet string""" return value @classmethod def decode(cls, constructed, content): """Decode a DER octet string""" if constructed: raise ASN1DecodeError('OCTET STRING should not be constructed') return content @DERTag(UTF8_STRING, (str,)) class _UTF8String: """A UTF-8 string value""" @staticmethod def encode(value): """Encode a DER UTF-8 string""" return value.encode('utf-8') @classmethod def decode(cls, constructed, content): """Decode a DER UTF-8 string""" if constructed: raise ASN1DecodeError('UTF8 STRING should not be constructed') return content.decode('utf-8') @DERTag(SEQUENCE, (list, tuple), constructed=True) class _Sequence: """A sequence of values""" @staticmethod def encode(value): """Encode a sequence of DER values""" return b''.join(der_encode(item) for item in value) @classmethod def decode(cls, constructed, content): """Decode a sequence of DER values""" if not constructed: raise ASN1DecodeError('SEQUENCE should always be constructed') offset = 0 length = len(content) value = [] while offset < length: # pylint: disable=unpacking-non-sequence item, consumed = der_decode(content[offset:], partial_ok=True) # pylint: enable=unpacking-non-sequence 
value.append(item) offset += consumed return tuple(value) @DERTag(SET, (set, frozenset), constructed=True) class _Set: """A set of DER values""" @staticmethod def encode(value): """Encode a set of DER values""" return b''.join(sorted(der_encode(item) for item in value)) @classmethod def decode(cls, constructed, content): """Decode a set of DER values""" if not constructed: raise ASN1DecodeError('SET should always be constructed') offset = 0 length = len(content) value = set() while offset < length: # pylint: disable=unpacking-non-sequence item, consumed = der_decode(content[offset:], partial_ok=True) # pylint: enable=unpacking-non-sequence value.add(item) offset += consumed return frozenset(value) @DERTag(BIT_STRING) class BitString: """A string of bits This object can be initialized either with a byte string and an optional count of the number of least-significant bits in the last byte which should not be included in the value, or with a string consisting only of the digits '0' and '1'. An optional 'named' flag can also be set, indicating that the BitString was specified with named bits, indicating that the proper DER encoding of it should strip any trailing zeroes. """ def __init__(self, value, unused=0, named=False): if unused < 0 or unused > 7: raise ASN1EncodeError('Unused bit count must be between 0 and 7') if isinstance(value, bytes): if unused: if not value: raise ASN1EncodeError('Can\'t have unused bits with empty ' 'value') elif value[-1] & ((1 << unused) - 1): raise ASN1EncodeError('Unused bits in value should be ' 'zero') elif isinstance(value, str): if unused: raise ASN1EncodeError('Unused bit count should not be set ' 'when providing a string') used = len(value) % 8 unused = 8 - used if used else 0 value += unused * '0' value = bytes(int(value[i:i+8], 2) for i in range(0, len(value), 8)) else: raise ASN1EncodeError('Unexpected type of bit string value') if named: while value and not value[-1] & (1 << unused): unused += 1 if unused == 8: value = value[:-1] unused = 0 self.value = value self.unused = unused def __str__(self): result = ''.join(bin(b)[2:].zfill(8) for b in self.value) if self.unused: result = result[:-self.unused] return result def __repr__(self): return "BitString('%s')" % self def __eq__(self, other): return (isinstance(other, type(self)) and self.value == other.value and self.unused == other.unused) def __hash__(self): return hash((self.value, self.unused)) def encode(self): """Encode a DER bit string""" return bytes((self.unused,)) + self.value @classmethod def decode(cls, constructed, content): """Decode a DER bit string""" if constructed: raise ASN1DecodeError('BIT STRING should not be constructed') if not content or content[0] > 7: raise ASN1DecodeError('Invalid unused bit count') return cls(content[1:], unused=content[0]) @DERTag(IA5_STRING) class IA5String: """An ASCII string value""" def __init__(self, value): self.value = value def __str__(self): return self.value def __repr__(self): return "IA5String('%s')" % self.value def __eq__(self, other): return isinstance(other, type(self)) and self.value == other.value def __hash__(self): return hash(self.value) def encode(self): """Encode a DER IA5 string""" # ASN.1 defines this type as only containing ASCII characters, but # some tools expecting ASN.1 allow IA5Strings to contain UTF-8 # characters, so we leave it up to the caller whether to resrict # the data to plain ASCII or not. 
if isinstance(self.value, str): return self.value.encode('utf-8') else: return self.value @classmethod def decode(cls, constructed, content): """Decode a DER IA5 string""" if constructed: raise ASN1DecodeError('IA5 STRING should not be constructed') return cls(content.decode('utf-8')) @DERTag(OBJECT_IDENTIFIER) class ObjectIdentifier: """An object identifier (OID) value This object can be initialized from a string of dot-separated integer values, representing a hierarchical namespace. All OIDs show have at least two components, with the first being between 0 and 2 (indicating ITU-T, ISO, or joint assignment). In cases where the first component is 0 or 1, the second component must be in the range 0 to 39 due to the way these first two components are encoded. """ def __init__(self, value): self.value = value def __str__(self): return self.value def __repr__(self): return "ObjectIdentifier('%s')" % self.value def __eq__(self, other): return isinstance(other, type(self)) and self.value == other.value def __hash__(self): return hash(self.value) def encode(self): """Encode a DER object identifier""" def _bytes(component): """Convert a single element of an OID to a DER byte string""" if component < 0: raise ASN1EncodeError('Components of object identifier must ' 'be greater than or equal to 0') result = [component & 0x7f] while component >= 0x80: component >>= 7 result.append(0x80 | (component & 0x7f)) return bytes(result[::-1]) try: components = [int(c) for c in self.value.split('.')] except ValueError: raise ASN1EncodeError('Component values must be integers') if len(components) < 2: raise ASN1EncodeError('Object identifiers must have at least two ' 'components') elif components[0] < 0 or components[0] > 2: raise ASN1EncodeError('First component of object identifier must ' 'be between 0 and 2') elif components[0] < 2 and (components[1] < 0 or components[1] > 39): raise ASN1EncodeError('Second component of object identifier must ' 'be between 0 and 39') components[0:2] = [components[0]*40 + components[1]] return b''.join(_bytes(c) for c in components) @classmethod def decode(cls, constructed, content): """Decode a DER object identifier""" if constructed: raise ASN1DecodeError('OBJECT IDENTIFIER should not be ' 'constructed') if not content: raise ASN1DecodeError('Empty object identifier') b = content[0] components = list(divmod(b, 40)) if b < 80 else [2, b-80] component = 0 for b in content[1:]: if b == 0x80 and component == 0: raise ASN1DecodeError('Invalid component') elif b < 0x80: components.append(component | b) component = 0 else: component |= b & 0x7f component <<= 7 if component: raise ASN1DecodeError('Incomplete component') return cls('.'.join(str(c) for c in components)) def der_encode(value): """Encode a value in DER format This function takes a Python value and encodes it in DER format. The following mapping of types is used: NoneType -> NULL bool -> BOOLEAN int -> INTEGER bytes, bytearray -> OCTET STRING str -> UTF8 STRING list, tuple -> SEQUENCE set, frozenset -> SET BitString -> BIT STRING ObjectIdentifier -> OBJECT IDENTIFIER An explicitly tagged DER object can be encoded by passing in a TaggedDERObject which specifies the ASN.1 class, tag, and value to encode. Other types can be encoded by passing in a RawDERObject which specifies the ASN.1 class, tag, and raw content octets to encode. 
""" t = type(value) if t in (RawDERObject, TaggedDERObject): identifier = value.encode_identifier() content = value.encode() elif t in _der_class_by_type: cls = _der_class_by_type[t] identifier = cls.identifier content = cls.encode(value) else: raise ASN1EncodeError('Cannot DER encode type %s' % t.__name__) length = len(content) if length < 0x80: len_bytes = bytes((length,)) else: len_bytes = length.to_bytes((length.bit_length() + 7) // 8, 'big') len_bytes = bytes((0x80 | len(len_bytes),)) + len_bytes return identifier + len_bytes + content def der_decode(data, partial_ok=False): """Decode a value in DER format This function takes a byte string in DER format and converts it to a corresponding set of Python objects. The following mapping of ASN.1 tags to Python types is used: NULL -> NoneType BOOLEAN -> bool INTEGER -> int OCTET STRING -> bytes UTF8 STRING -> str SEQUENCE -> tuple SET -> frozenset BIT_STRING -> BitString OBJECT IDENTIFIER -> ObjectIdentifier Explicitly tagged objects are returned as type TaggedDERObject, with fields holding the object class, tag, and tagged value. Other object tags are returned as type RawDERObject, with fields holding the object class, tag, and raw content octets. If partial_ok is True, this function returns a tuple of the decoded value and number of bytes consumed. Otherwise, all data bytes must be consumed and only the decoded value is returned. """ if len(data) < 2: raise ASN1DecodeError('Incomplete data') tag = data[0] asn1_class, constructed, tag = tag >> 6, bool(tag & 0x20), tag & 0x1f offset = 1 if tag == 0x1f: tag = 0 for b in data[offset:]: offset += 1 if b < 0x80: tag |= b break else: tag |= b & 0x7f tag <<= 7 else: raise ASN1DecodeError('Incomplete tag') if offset >= len(data): raise ASN1DecodeError('Incomplete data') length = data[offset] offset += 1 if length > 0x80: len_size = length & 0x7f length = int.from_bytes(data[offset:offset+len_size], 'big') offset += len_size elif length == 0x80: raise ASN1DecodeError('Indefinite length not allowed') if offset+length > len(data): raise ASN1DecodeError('Incomplete data') if not partial_ok and offset+length < len(data): raise ASN1DecodeError('Data contains unexpected bytes at end') if asn1_class == UNIVERSAL and tag in _der_class_by_tag: cls = _der_class_by_tag[tag] value = cls.decode(constructed, data[offset:offset+length]) elif constructed: value = TaggedDERObject(tag, der_decode(data[offset:offset+length]), asn1_class) else: value = RawDERObject(tag, data[offset:offset+length], asn1_class) if partial_ok: return value, offset+length else: return value asyncssh-1.11.1/asyncssh/auth.py000066400000000000000000000642741320320510200165750ustar00rootroot00000000000000# Copyright (c) 2013-2017 by Ron Frederick . # All rights reserved. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH authentication handlers""" import asyncio from .constants import DEFAULT_LANG, DISC_PROTOCOL_ERROR from .gss import GSSError from .logging import logger from .misc import DisconnectError, PasswordChangeRequired from .packet import Boolean, Byte, String, UInt32, SSHPacketHandler from .saslprep import saslprep, SASLPrepError # pylint: disable=bad-whitespace # SSH message values for GSS auth MSG_USERAUTH_GSSAPI_RESPONSE = 60 MSG_USERAUTH_GSSAPI_TOKEN = 61 MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE = 63 MSG_USERAUTH_GSSAPI_ERROR = 64 MSG_USERAUTH_GSSAPI_ERRTOK = 65 MSG_USERAUTH_GSSAPI_MIC = 66 # SSH message values for public key auth MSG_USERAUTH_PK_OK = 60 # SSH message values for password auth MSG_USERAUTH_PASSWD_CHANGEREQ = 60 # SSH message values for 'keyboard-interactive' auth MSG_USERAUTH_INFO_REQUEST = 60 MSG_USERAUTH_INFO_RESPONSE = 61 # pylint: enable=bad-whitespace _auth_methods = [] _client_auth_handlers = {} _server_auth_handlers = {} class _Auth(SSHPacketHandler): """Parent class for authentication""" def __init__(self, conn, coro): self._conn = conn self._coro = conn.create_task(coro) def create_task(self, coro): """Create an asynchronous auth task""" self.cancel() self._coro = self._conn.create_task(coro) def cancel(self): """Cancel any authentication in progress""" if self._coro: # pragma: no branch self._coro.cancel() self._coro = None class _ClientAuth(_Auth): """Parent class for client authentication""" def __init__(self, conn, method): self._method = method super().__init__(conn, self._start()) @asyncio.coroutine def _start(self): """Abstract method for starting client authentication""" # Provided by subclass raise NotImplementedError def auth_succeeded(self): """Callback when auth succeeds""" def auth_failed(self): """Callback when auth fails""" @asyncio.coroutine def send_request(self, *args, key=None): """Send a user authentication request""" yield from self._conn.send_userauth_request(self._method, *args, key=key) class _ClientNullAuth(_ClientAuth): """Client side implementation of null auth""" @asyncio.coroutine def _start(self): """Start client null authentication""" yield from self.send_request() class _ClientGSSKexAuth(_ClientAuth): """Client side implementation of GSS key exchange auth""" @asyncio.coroutine def _start(self): """Start client GSS key exchange authentication""" if self._conn.gss_kex_auth_requested(): yield from self.send_request(key=self._conn.get_gss_context()) else: self._conn.try_next_auth() class _ClientGSSMICAuth(_ClientAuth): """Client side implementation of GSS MIC auth""" def __init__(self, conn, method): super().__init__(conn, method) self._gss = None self._got_error = False @asyncio.coroutine def _start(self): """Start client GSS MIC authentication""" if self._conn.gss_mic_auth_requested(): self._gss = self._conn.get_gss_context() mechs = b''.join((String(mech) for mech in self._gss.mechs)) yield from self.send_request(UInt32(len(self._gss.mechs)), mechs) else: self._conn.try_next_auth() def _finish(self): """Finish client GSS MIC authentication""" if self._gss.provides_integrity: data = self._conn.get_userauth_request_data(self._method) self._conn.send_packet(Byte(MSG_USERAUTH_GSSAPI_MIC), String(self._gss.sign(data))) else: 
self._conn.send_packet(Byte(MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE)) def _process_response(self, pkttype, packet): """Process a GSS response from the server""" # pylint: disable=unused-argument mech = packet.get_string() packet.check_end() if mech not in self._gss.mechs: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Mechanism mismatch') try: token = self._gss.step() self._conn.send_packet(Byte(MSG_USERAUTH_GSSAPI_TOKEN), String(token)) if self._gss.complete: self._finish() except GSSError as exc: if exc.token: self._conn.send_packet(Byte(MSG_USERAUTH_GSSAPI_ERRTOK), String(exc.token)) self._conn.try_next_auth() return True def _process_token(self, pkttype, packet): """Process a GSS token from the server""" # pylint: disable=unused-argument token = packet.get_string() packet.check_end() try: token = self._gss.step(token) if token: self._conn.send_packet(Byte(MSG_USERAUTH_GSSAPI_TOKEN), String(token)) if self._gss.complete: self._finish() except GSSError as exc: if exc.token: self._conn.send_packet(Byte(MSG_USERAUTH_GSSAPI_ERRTOK), String(exc.token)) self._conn.try_next_auth() return True def _process_error(self, pkttype, packet): """Process a GSS error from the server""" # pylint: disable=unused-argument _ = packet.get_uint32() # major_status _ = packet.get_uint32() # minor_status msg = packet.get_string() _ = packet.get_string() # lang packet.check_end() logger.warning('GSS error from server: %s', msg.decode('utf-8', errors='ignore')) self._got_error = True return True def _process_error_token(self, pkttype, packet): """Process a GSS error token from the server""" # pylint: disable=no-self-use,unused-argument token = packet.get_string() packet.check_end() try: self._gss.step(token) except GSSError as exc: if not self._got_error: # pragma: no cover logger.warning('GSS error from server: %s', str(exc)) return True # pylint: disable=bad-whitespace packet_handlers = { MSG_USERAUTH_GSSAPI_RESPONSE: _process_response, MSG_USERAUTH_GSSAPI_TOKEN: _process_token, MSG_USERAUTH_GSSAPI_ERROR: _process_error, MSG_USERAUTH_GSSAPI_ERRTOK: _process_error_token } # pylint: enable=bad-whitespace class _ClientPublicKeyAuth(_ClientAuth): """Client side implementation of public key auth""" @asyncio.coroutine def _start(self): """Start client public key authentication""" self._keypair = yield from self._conn.public_key_auth_requested() if self._keypair is None: self._conn.try_next_auth() return yield from self.send_request(Boolean(False), String(self._keypair.algorithm), String(self._keypair.public_data)) @asyncio.coroutine def _send_signed_request(self): """Send signed public key request""" yield from self.send_request(Boolean(True), String(self._keypair.algorithm), String(self._keypair.public_data), key=self._keypair) def _process_public_key_ok(self, pkttype, packet): """Process a public key ok response""" # pylint: disable=unused-argument algorithm = packet.get_string() key_data = packet.get_string() packet.check_end() if (algorithm != self._keypair.algorithm or key_data != self._keypair.public_data): raise DisconnectError(DISC_PROTOCOL_ERROR, 'Key mismatch') self.create_task(self._send_signed_request()) return True packet_handlers = { MSG_USERAUTH_PK_OK: _process_public_key_ok } class _ClientKbdIntAuth(_ClientAuth): """Client side implementation of keyboard-interactive auth""" @asyncio.coroutine def _start(self): """Start client keyboard interactive authentication""" submethods = yield from self._conn.kbdint_auth_requested() if submethods is None: self._conn.try_next_auth() return yield from 
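# --- Example sketch added by the editor (not part of the original asyncssh source) ---
# Client-side public key auth (handled by _ClientPublicKeyAuth above) is
# normally driven by passing client keys into the connection via the public
# asyncssh API.  A minimal sketch in the 3.4-style coroutine idiom used by this
# codebase; the host name and key path are placeholders:

import asyncio, asyncssh

@asyncio.coroutine
def connect_with_key():
    conn = yield from asyncssh.connect('host.example.com', username='user',
                                       client_keys=['/home/user/.ssh/id_rsa'])
    try:
        pass                             # open sessions here
    finally:
        conn.close()
        yield from conn.wait_closed()

# asyncio.get_event_loop().run_until_complete(connect_with_key())
# --- end of example sketch ---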
self.send_request(String(''), String(submethods)) @asyncio.coroutine def _receive_challenge(self, name, instruction, lang, prompts): """Receive and respond to a keyboard interactive challenge""" responses = \ yield from self._conn.kbdint_challenge_received(name, instruction, lang, prompts) if responses is None: self._conn.try_next_auth() return self._conn.send_packet(Byte(MSG_USERAUTH_INFO_RESPONSE), UInt32(len(responses)), b''.join(String(r) for r in responses)) def _process_info_request(self, pkttype, packet): """Process a keyboard interactive authentication request""" # pylint: disable=unused-argument name = packet.get_string() instruction = packet.get_string() lang = packet.get_string() try: name = name.decode('utf-8') instruction = instruction.decode('utf-8') lang = lang.decode('ascii') except UnicodeDecodeError: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Invalid keyboard ' 'interactive info request') from None num_prompts = packet.get_uint32() prompts = [] for _ in range(num_prompts): prompt = packet.get_string() echo = packet.get_boolean() try: prompt = prompt.decode('utf-8') except UnicodeDecodeError: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Invalid keyboard ' 'interactive info request') from None prompts.append((prompt, echo)) self.create_task(self._receive_challenge(name, instruction, lang, prompts)) return True packet_handlers = { MSG_USERAUTH_INFO_REQUEST: _process_info_request } class _ClientPasswordAuth(_ClientAuth): """Client side implementation of password auth""" def __init__(self, conn, method): super().__init__(conn, method) self._password_change = False @asyncio.coroutine def _start(self): """Start client password authentication""" password = yield from self._conn.password_auth_requested() if password is None: self._conn.try_next_auth() return yield from self.send_request(Boolean(False), String(password)) @asyncio.coroutine def _change_password(self, prompt, lang): """Start password change""" result = yield from self._conn.password_change_requested(prompt, lang) if result == NotImplemented: # Password change not supported - move on to the next auth method self._conn.try_next_auth() return old_password, new_password = result self._password_change = True yield from self.send_request(Boolean(True), String(old_password.encode('utf-8')), String(new_password.encode('utf-8'))) def auth_succeeded(self): if self._password_change: self._password_change = False self._conn.password_changed() def auth_failed(self): if self._password_change: self._password_change = False self._conn.password_change_failed() def _process_password_change(self, pkttype, packet): """Process a password change request""" # pylint: disable=unused-argument prompt = packet.get_string() lang = packet.get_string() try: prompt = prompt.decode('utf-8') lang = lang.decode('ascii') except UnicodeDecodeError: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Invalid password change request') from None self.auth_failed() self.create_task(self._change_password(prompt, lang)) return True packet_handlers = { MSG_USERAUTH_PASSWD_CHANGEREQ: _process_password_change } class _ServerAuth(_Auth): """Parent class for server authentication""" def __init__(self, conn, username, method, packet): self._username = username self._method = method super().__init__(conn, self._start(packet)) @asyncio.coroutine def _start(self, packet): """Abstract method for starting server authentication""" # Provided by subclass raise NotImplementedError def send_failure(self, partial_success=False): """Send a user authentication failure response""" 
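# --- Example sketch added by the editor (not part of the original asyncssh source) ---
# The client password and keyboard-interactive handlers above obtain their
# credentials by calling back into the application's SSHClient object.  A
# minimal client that supplies a password; the credentials and host are
# placeholders, and host key checking is disabled purely to keep the sketch short:

import asyncio, asyncssh

class PasswordClient(asyncssh.SSHClient):
    def password_auth_requested(self):
        # Returning None here would make the connection move on to the
        # next available authentication method instead
        return 'secret-password'

@asyncio.coroutine
def connect_with_password():
    conn, _client = yield from asyncssh.create_connection(
        PasswordClient, 'host.example.com', username='user', known_hosts=None)
    conn.close()
    yield from conn.wait_closed()
# --- end of example sketch ---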
self._conn.send_userauth_failure(partial_success) def send_success(self): """Send a user authentication success response""" self._conn.send_userauth_success() class _ServerNullAuth(_ServerAuth): """Server side implementation of null auth""" @classmethod def supported(cls, conn): """Return that null authentication is never a supported auth mode""" # pylint: disable=unused-argument return False @asyncio.coroutine def _start(self, packet): """Supported always returns false, so we never get here""" class _ServerGSSKexAuth(_ServerAuth): """Server side implementation of GSS key exchange auth""" def __init__(self, conn, username, method, packet): super().__init__(conn, username, method, packet) self._gss = conn.get_gss_context() @classmethod def supported(cls, conn): """Return whether GSS key exchange authentication is supported""" return conn.gss_kex_auth_supported() @asyncio.coroutine def _start(self, packet): """Start server GSS key exchange authentication""" mic = packet.get_string() packet.check_end() data = self._conn.get_userauth_request_data(self._method) if (self._gss.complete and self._gss.verify(data, mic) and (yield from self._conn.validate_gss_principal(self._username, self._gss.user, self._gss.host))): self.send_success() else: self.send_failure() class _ServerGSSMICAuth(_ServerAuth): """Server side implementation of GSS MIC auth""" def __init__(self, conn, username, method, packet): super().__init__(conn, username, method, packet) self._gss = conn.get_gss_context() @classmethod def supported(cls, conn): """Return whether GSS MIC authentication is supported""" return conn.gss_mic_auth_supported() @asyncio.coroutine def _start(self, packet): """Start server GSS MIC authentication""" mechs = set() n = packet.get_uint32() for _ in range(n): mechs.add(packet.get_string()) packet.check_end() match = None for mech in self._gss.mechs: if mech in mechs: match = mech break if not match: self.send_failure() return self._conn.send_packet(Byte(MSG_USERAUTH_GSSAPI_RESPONSE), String(match)) @asyncio.coroutine def _finish(self): """Finish server GSS MIC authentication""" if (yield from self._conn.validate_gss_principal(self._username, self._gss.user, self._gss.host)): self.send_success() else: self.send_failure() def _process_token(self, pkttype, packet): """Process a GSS token from the client""" # pylint: disable=unused-argument token = packet.get_string() packet.check_end() try: token = self._gss.step(token) if token: self._conn.send_packet(Byte(MSG_USERAUTH_GSSAPI_TOKEN), String(token)) except GSSError as exc: self._conn.send_packet(Byte(MSG_USERAUTH_GSSAPI_ERROR), UInt32(exc.maj_code), UInt32(exc.min_code), String(str(exc)), String(DEFAULT_LANG)) if exc.token: self._conn.send_packet(Byte(MSG_USERAUTH_GSSAPI_ERRTOK), String(exc.token)) self.send_failure() return True def _process_exchange_complete(self, pkttype, packet): """Process a GSS exchange complete message from the client""" # pylint: disable=unused-argument packet.check_end() if self._gss.complete and not self._gss.provides_integrity: self.create_task(self._finish()) else: self.send_failure() return True def _process_error_token(self, pkttype, packet): """Process a GSS error token from the client""" # pylint: disable=unused-argument token = packet.get_string() packet.check_end() try: self._gss.step(token) except GSSError as exc: logger.warning('GSS error from client: %s', str(exc)) return True def _process_mic(self, pkttype, packet): """Process a GSS MIC from the client""" # pylint: disable=unused-argument mic = packet.get_string() 
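# --- Example sketch added by the editor (not part of the original asyncssh source) ---
# The GSS handlers above finish by asking the connection to validate the
# authenticated GSS principal, and on the server side that check is normally
# answered by the application's SSHServer object.  A minimal sketch, assuming
# the callback is named validate_gss_principal as in the public SSHServer API;
# the realm and policy shown are placeholders:

import asyncssh

class GSSPolicyServer(asyncssh.SSHServer):
    def validate_gss_principal(self, username, user_principal, host_principal):
        # Placeholder policy: accept only principals from one realm that
        # match the requested username
        return user_principal == username + '@EXAMPLE.COM'
# --- end of example sketch ---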
packet.check_end() data = self._conn.get_userauth_request_data(self._method) if (self._gss.complete and self._gss.provides_integrity and self._gss.verify(data, mic)): self.create_task(self._finish()) else: self.send_failure() return True # pylint: disable=bad-whitespace packet_handlers = { MSG_USERAUTH_GSSAPI_TOKEN: _process_token, MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE: _process_exchange_complete, MSG_USERAUTH_GSSAPI_ERRTOK: _process_error_token, MSG_USERAUTH_GSSAPI_MIC: _process_mic } # pylint: enable=bad-whitespace class _ServerPublicKeyAuth(_ServerAuth): """Server side implementation of public key auth""" @classmethod def supported(cls, conn): """Return whether public key authentication is supported""" return conn.public_key_auth_supported() @asyncio.coroutine def _start(self, packet): """Start server public key authentication""" sig_present = packet.get_boolean() algorithm = packet.get_string() key_data = packet.get_string() if sig_present: msg = packet.get_consumed_payload() signature = packet.get_string() else: msg = None signature = None packet.check_end() if (yield from self._conn.validate_public_key(self._username, key_data, msg, signature)): if sig_present: self.send_success() else: self._conn.send_packet(Byte(MSG_USERAUTH_PK_OK), String(algorithm), String(key_data)) else: self.send_failure() class _ServerKbdIntAuth(_ServerAuth): """Server side implementation of keyboard-interactive auth""" @classmethod def supported(cls, conn): """Return whether keyboard interactive authentication is supported""" return conn.kbdint_auth_supported() @asyncio.coroutine def _start(self, packet): """Start server keyboard interactive authentication""" lang = packet.get_string() submethods = packet.get_string() packet.check_end() try: lang = lang.decode('ascii') submethods = submethods.decode('utf-8') except UnicodeDecodeError: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Invalid keyboard ' 'interactive auth request') from None challenge = yield from self._conn.get_kbdint_challenge(self._username, lang, submethods) self._send_challenge(challenge) def _send_challenge(self, challenge): """Send a keyboard interactive authentication request""" if isinstance(challenge, (tuple, list)): name, instruction, lang, prompts = challenge num_prompts = len(prompts) prompts = (String(prompt) + Boolean(echo) for prompt, echo in prompts) self._conn.send_packet(Byte(MSG_USERAUTH_INFO_REQUEST), String(name), String(instruction), String(lang), UInt32(num_prompts), *prompts) elif challenge: self.send_success() else: self.send_failure() @asyncio.coroutine def _validate_response(self, responses): """Validate a keyboard interactive authentication response""" next_challenge = \ yield from self._conn.validate_kbdint_response(self._username, responses) self._send_challenge(next_challenge) def _process_info_response(self, pkttype, packet): """Process a keyboard interactive authentication response""" # pylint: disable=unused-argument num_responses = packet.get_uint32() responses = [] for _ in range(num_responses): response = packet.get_string() try: response = response.decode('utf-8') except UnicodeDecodeError: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Invalid keyboard ' 'interactive info response') from None responses.append(response) packet.check_end() self.create_task(self._validate_response(responses)) return True packet_handlers = { MSG_USERAUTH_INFO_RESPONSE: _process_info_response } class _ServerPasswordAuth(_ServerAuth): """Server side implementation of password auth""" @classmethod def supported(cls, conn): """Return 
whether password authentication is supported""" return conn.password_auth_supported() @asyncio.coroutine def _start(self, packet): """Start server password authentication""" password_change = packet.get_boolean() password = packet.get_string() new_password = packet.get_string() if password_change else b'' packet.check_end() try: password = saslprep(password.decode('utf-8')) new_password = saslprep(new_password.decode('utf-8')) except (UnicodeDecodeError, SASLPrepError): raise DisconnectError(DISC_PROTOCOL_ERROR, 'Invalid password auth ' 'request') from None try: if password_change: result = yield from self._conn.change_password(self._username, password, new_password) else: result = \ yield from self._conn.validate_password(self._username, password) if result: self.send_success() else: self.send_failure() except PasswordChangeRequired as exc: self._conn.send_packet(Byte(MSG_USERAUTH_PASSWD_CHANGEREQ), String(exc.prompt), String(exc.lang)) def register_auth_method(alg, client_handler, server_handler): """Register an authentication method""" _auth_methods.append(alg) _client_auth_handlers[alg] = client_handler _server_auth_handlers[alg] = server_handler def lookup_client_auth(conn, method): """Look up the client authentication method to use""" if method in _auth_methods: return _client_auth_handlers[method](conn, method) else: return None def get_server_auth_methods(conn): """Return a list of supported auth methods""" auth_methods = [] for method in _auth_methods: if _server_auth_handlers[method].supported(conn): auth_methods.append(method) return auth_methods def lookup_server_auth(conn, username, method, packet): """Look up the server authentication method to use""" handler = _server_auth_handlers.get(method) if handler and handler.supported(conn): return handler(conn, username, method, packet) else: conn.send_userauth_failure(False) return None # pylint: disable=bad-whitespace _auth_method_list = ( (b'none', _ClientNullAuth, _ServerNullAuth), (b'gssapi-keyex', _ClientGSSKexAuth, _ServerGSSKexAuth), (b'gssapi-with-mic', _ClientGSSMICAuth, _ServerGSSMICAuth), (b'publickey', _ClientPublicKeyAuth, _ServerPublicKeyAuth), (b'keyboard-interactive', _ClientKbdIntAuth, _ServerKbdIntAuth), (b'password', _ClientPasswordAuth, _ServerPasswordAuth) ) # pylint: enable=bad-whitespace for _args in _auth_method_list: register_auth_method(*_args) asyncssh-1.11.1/asyncssh/auth_keys.py000066400000000000000000000220231320320510200176120ustar00rootroot00000000000000# Copyright (c) 2015-2017 by Ron Frederick . # All rights reserved. 
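# --- Example sketch added by the editor (not part of the original asyncssh source) ---
# The server-side password and keyboard-interactive handlers in the preceding
# module call back into the application's SSHServer object.  A minimal sketch
# with placeholder credentials; raising PasswordChangeRequired triggers the
# MSG_USERAUTH_PASSWD_CHANGEREQ flow handled by _ServerPasswordAuth:

import asyncssh

class ExampleAuthServer(asyncssh.SSHServer):
    def begin_auth(self, username):
        return True                      # always require authentication

    def password_auth_supported(self):
        return True

    def validate_password(self, username, password):
        if username == 'expired':
            raise asyncssh.PasswordChangeRequired('Please choose a new password')
        return (username, password) == ('guest', 'guest')

    def change_password(self, username, old_password, new_password):
        # Called when the client answers a password change request
        return old_password != new_password

    def kbdint_auth_supported(self):
        return True

    def get_kbdint_challenge(self, username, lang, submethods):
        # Same (name, instruction, lang, prompts) tuple format that
        # _ServerKbdIntAuth sends in MSG_USERAUTH_INFO_REQUEST
        return ('', '', '', [('One-time code: ', False)])

    def validate_kbdint_response(self, username, responses):
        # True/False finishes auth; another challenge tuple prompts again
        return responses == ['123456']
# --- end of example sketch ---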
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Parser for SSH known_hosts files""" import socket try: from .crypto import X509NamePattern _x509_available = True except ImportError: # pragma: no cover _x509_available = False from .misc import ip_address from .pattern import HostPatternList, WildcardPatternList from .public_key import KeyImportError, import_public_key from .public_key import import_certificate, import_certificate_subject class _SSHAuthorizedKeyEntry: """An entry in an SSH authorized_keys list""" def __init__(self, line): self.key = None self.cert = None self.options = {} try: self._import_key_or_cert(line) return except KeyImportError: pass line = self._parse_options(line) self._import_key_or_cert(line) def _import_key_or_cert(self, line): """Import key or certificate in this entry""" try: self.key = import_public_key(line) return except KeyImportError: pass try: self.cert = import_certificate(line) if ('cert-authority' in self.options and self.cert.subject != self.cert.issuer): raise ValueError('X.509 cert-authority entries must ' 'contain a root CA certificate') return except KeyImportError: pass if 'cert-authority' not in self.options: try: self.key = None self.cert = None self._add_subject('subject', import_certificate_subject(line)) return except KeyImportError: pass raise KeyImportError('Unrecognized key, certificate, or subject') def _set_string(self, option, value): """Set an option with a string value""" self.options[option] = value def _add_environment(self, option, value): """Add an environment key/value pair""" if value.startswith('=') or '=' not in value: raise ValueError('Invalid environment entry in authorized_keys') name, value = value.split('=', 1) self.options.setdefault(option, {})[name] = value def _add_from(self, option, value): """Add a from host pattern""" self.options.setdefault(option, []).append(HostPatternList(value)) def _add_permitopen(self, option, value): """Add a permitopen host/port pair""" try: host, port = value.rsplit(':', 1) if host.startswith('[') and host.endswith(']'): host = host[1:-1] port = None if port == '*' else int(port) except: raise ValueError('Illegal permitopen value: %s' % value) from None self.options.setdefault(option, set()).add((host, port)) def _add_principals(self, option, value): """Add a principals wildcard pattern list""" self.options.setdefault(option, []).append(WildcardPatternList(value)) def _add_subject(self, option, value): """Add an X.509 subject pattern""" if _x509_available: # pragma: no branch self.options.setdefault(option, []).append(X509NamePattern(value)) _handlers = { 'command': _set_string, 'environment': _add_environment, 'from': _add_from, 'permitopen': _add_permitopen, 'principals': _add_principals, 'subject': _add_subject } def _add_option(self): """Add an option value""" if self._option.startswith('='): raise ValueError('Missing option name in authorized_keys') if '=' in self._option: option, value = self._option.split('=', 1) handler = self._handlers.get(option) if handler: handler(self, option, value) else: self.options.setdefault(option, []).append(value) else: self.options[self._option] = True def _parse_options(self, line): """Parse options in this entry""" self._option = '' idx = 0 quoted = False escaped = False for idx, ch in 
enumerate(line): if escaped: self._option += ch escaped = False elif ch == '\\': escaped = True elif ch == '"': quoted = not quoted elif quoted: self._option += ch elif ch in ' \t': break elif ch == ',': self._add_option() self._option = '' else: self._option += ch self._add_option() if quoted: raise ValueError('Unbalanced quote in authorized_keys') elif escaped: raise ValueError('Unbalanced backslash in authorized_keys') return line[idx:].strip() def match_options(self, client_addr, cert_principals, cert_subject=None): """Match "from", "principals" and "subject" options in entry""" from_patterns = self.options.get('from') if from_patterns: client_host, _ = socket.getnameinfo((client_addr, 0), socket.NI_NUMERICSERV) client_ip = ip_address(client_addr) if not all(pattern.matches(client_host, client_addr, client_ip) for pattern in from_patterns): return False principal_patterns = self.options.get('principals') if cert_principals is not None and principal_patterns is not None: if not all(any(pattern.matches(principal) for principal in cert_principals) for pattern in principal_patterns): return False subject_patterns = self.options.get('subject') if cert_subject is not None and subject_patterns is not None: if not all(pattern.matches(cert_subject) for pattern in subject_patterns): return False return True class SSHAuthorizedKeys: """An SSH authorized keys list""" def __init__(self, data): self._user_entries = [] self._ca_entries = [] self._x509_entries = [] for line in data.splitlines(): line = line.strip() if not line or line.startswith('#'): continue try: entry = _SSHAuthorizedKeyEntry(line) except KeyImportError: continue if entry.key: if 'cert-authority' in entry.options: self._ca_entries.append(entry) else: self._user_entries.append(entry) else: self._x509_entries.append(entry) if (not self._user_entries and not self._ca_entries and not self._x509_entries): raise ValueError('No valid entries found') def validate(self, key, client_addr, cert_principals=None, ca=False): """Return whether a public key or CA is valid for authentication""" for entry in self._ca_entries if ca else self._user_entries: if (entry.key == key and entry.match_options(client_addr, cert_principals)): return entry.options return None def validate_x509(self, cert, client_addr): """Return whether an X.509 certificate is valid for authentication""" for entry in self._x509_entries: if (entry.cert and 'cert-authority' not in entry.options and (cert.key != entry.cert.key or cert.subject != entry.cert.subject)): continue # pragma: no cover (work around bug in coverage tool) if entry.match_options(client_addr, cert.user_principals, cert.subject): return entry.options, entry.cert return None, None def import_authorized_keys(data): """Import SSH authorized keys This function imports public keys and associated options in OpenSSH authorized keys format. :param str data: The key data to import. :returns: An :class:`SSHAuthorizedKeys` object """ return SSHAuthorizedKeys(data) def read_authorized_keys(filename): """Read SSH authorized keys from a file This function reads public keys and associated options in OpenSSH authorized_keys format from a file. :param str filename: The file to read the keys from. :returns: An :class:`SSHAuthorizedKeys` object """ with open(filename, 'r') as f: return import_authorized_keys(f.read()) asyncssh-1.11.1/asyncssh/channel.py000066400000000000000000001634771320320510200172510ustar00rootroot00000000000000# Copyright (c) 2013-2016 by Ron Frederick . # All rights reserved. 
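# --- Example sketch added by the editor (not part of the original asyncssh source) ---
# Example use of the authorized_keys parser defined in the preceding module,
# assuming import_authorized_keys() and import_public_key() are exported from
# the top-level asyncssh package.  The key file path, options, and client
# addresses below are placeholders:

import asyncssh

with open('/home/user/.ssh/id_rsa.pub') as keyfile:
    pubkey_line = keyfile.read().strip()

entry = 'command="/usr/bin/true",from="10.0.0.*" ' + pubkey_line + '\n'
auth_keys = asyncssh.import_authorized_keys(entry)

key = asyncssh.import_public_key(pubkey_line)

# validate() returns the parsed options when the key matches and the "from"
# pattern accepts the client address, and None otherwise
assert auth_keys.validate(key, '10.0.0.5') is not None
assert auth_keys.validate(key, '192.0.2.1') is None
# --- end of example sketch ---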
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH channel and session handlers""" import asyncio import binascii from .constants import DEFAULT_LANG, DISC_PROTOCOL_ERROR, EXTENDED_DATA_STDERR from .constants import MSG_CHANNEL_OPEN, MSG_CHANNEL_WINDOW_ADJUST from .constants import MSG_CHANNEL_DATA, MSG_CHANNEL_EXTENDED_DATA from .constants import MSG_CHANNEL_EOF, MSG_CHANNEL_CLOSE, MSG_CHANNEL_REQUEST from .constants import MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE from .constants import OPEN_CONNECT_FAILED, PTY_OP_RESERVED, PTY_OP_END from .constants import OPEN_REQUEST_X11_FORWARDING_FAILED from .constants import OPEN_REQUEST_PTY_FAILED, OPEN_REQUEST_SESSION_FAILED from .editor import SSHLineEditorChannel, SSHLineEditorSession from .misc import ChannelOpenError, DisconnectError, map_handler_name from .packet import Boolean, Byte, String, UInt32, SSHPacketHandler class SSHChannel(SSHPacketHandler): """Parent class for SSH channels""" _read_datatypes = set() _write_datatypes = set() def __init__(self, conn, loop, encoding, window, max_pktsize): """Initialize an SSH channel If encoding is set, data sent and received will be in the form of strings, converted on the wire to bytes using the specified encoding. If encoding is None, data sent and received must be provided as bytes. Window specifies the initial receive window size. Max_pktsize specifies the maximum length of a single data packet. """ self._conn = conn self._loop = loop self._session = None self._encoding = encoding self._extra = {'connection': conn} self._env = {} self._command = None self._subsystem = None self._send_state = 'closed' self._send_chan = None self._send_window = None self._send_pktsize = None self._send_paused = False self._send_buf = [] self._send_buf_len = 0 self._recv_state = 'closed' self._init_recv_window = window self._recv_window = window self._recv_pktsize = max_pktsize self._recv_paused = True self._recv_buf = [] self._recv_partial = {} self._request_queue = [] self._open_waiter = None self._request_waiters = [] self._close_event = asyncio.Event(loop=loop) self.set_write_buffer_limits() self._recv_chan = conn.add_channel(self) def get_connection(self): """Return the connection used by this channel""" return self._conn def get_loop(self): """Return the event loop used by this channel""" return self._loop def get_encoding(self): """Return the encoding used by this channel""" return self._encoding def set_encoding(self, encoding): """Set the encoding on this channel""" self._encoding = encoding def get_recv_window(self): """Return the configured receive window for this channel""" return self._init_recv_window def get_read_datatypes(self): """Return the legal read data types for this channel""" return self._read_datatypes def get_write_datatypes(self): """Return the legal write data types for this channel""" return self._write_datatypes def _cleanup(self, exc=None): """Clean up this channel""" if self._open_waiter: if not self._open_waiter.cancelled(): # pragma: no branch self._open_waiter.set_exception( ChannelOpenError(OPEN_CONNECT_FAILED, 'SSH connection closed')) self._open_waiter = None if self._request_waiters: for waiter in self._request_waiters: if not waiter.cancelled(): # pragma: no branch waiter.set_exception(exc) self._request_waiters = [] 
if self._session: self._session.connection_lost(exc) self._session = None self._close_event.set() if self._conn: # pragma: no branch self._conn.remove_channel(self._recv_chan) self._recv_chan = None self._conn = None def _close_send(self): """Discard unsent data and close the channel for sending""" # Discard unsent data self._send_buf = [] self._send_buf_len = 0 if self._send_state != 'closed': self._send_packet(MSG_CHANNEL_CLOSE) self._send_chan = None self._send_state = 'closed' def _discard_recv(self): """Discard unreceived data and clean up if close received""" # Discard unreceived data self._recv_buf = [] self._recv_paused = False # If recv is close_pending, we know send is already closed if self._recv_state == 'close_pending': self._recv_state = 'closed' self._loop.call_soon(self._cleanup) def _pause_resume_writing(self): """Pause or resume writing based on send buffer low/high water marks""" if self._send_paused: if self._send_buf_len <= self._send_low_water: self._send_paused = False self._session.resume_writing() else: if self._send_buf_len > self._send_high_water: self._send_paused = True self._session.pause_writing() def _flush_send_buf(self): """Flush as much data in send buffer as the send window allows""" while self._send_buf and self._send_window: pktsize = min(self._send_window, self._send_pktsize) buf, datatype = self._send_buf[0] if len(buf) > pktsize: data = buf[:pktsize] del buf[:pktsize] else: data = buf del self._send_buf[0] self._send_buf_len -= len(data) self._send_window -= len(data) if datatype is None: self._send_packet(MSG_CHANNEL_DATA, String(data)) else: self._send_packet(MSG_CHANNEL_EXTENDED_DATA, UInt32(datatype), String(data)) self._pause_resume_writing() if not self._send_buf: if self._send_state == 'eof_pending': self._send_packet(MSG_CHANNEL_EOF) self._send_state = 'eof' elif self._send_state == 'close_pending': self._close_send() def _flush_recv_buf(self, exc=None): """Flush as much data in the recv buffer as the application allows""" while self._recv_buf and not self._recv_paused: self._deliver_data(*self._recv_buf.pop(0)) if not self._recv_buf: if self._recv_state == 'eof_pending': self._recv_state = 'eof' if (not self._session.eof_received() and self._send_state == 'open'): self.write_eof() elif self._recv_state == 'close_pending': self._recv_state = 'closed' if self._recv_partial and not exc: exc = DisconnectError(DISC_PROTOCOL_ERROR, 'Unicode decode error') self._loop.call_soon(self._cleanup, exc) def _deliver_data(self, data, datatype): """Deliver incoming data to the session""" self._recv_window -= len(data) if self._recv_window < self._init_recv_window / 2: self._send_packet(MSG_CHANNEL_WINDOW_ADJUST, UInt32(self._init_recv_window - self._recv_window)) self._recv_window = self._init_recv_window if self._encoding: if datatype in self._recv_partial: encdata = self._recv_partial.pop(datatype) + data else: encdata = data while encdata: try: data = encdata.decode(self._encoding) encdata = b'' except UnicodeDecodeError as exc: if exc.start > 0: # Avoid pylint false positive # pylint: disable=invalid-slice-index data = encdata[:exc.start].decode() encdata = encdata[exc.start:] elif exc.reason == 'unexpected end of data': break else: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Unicode decode error') self._session.data_received(data, datatype) if encdata: self._recv_partial[datatype] = encdata else: self._session.data_received(data, datatype) def _accept_data(self, data, datatype=None): """Accept new data on the channel This method accepts new data on 
the channel, immediately delivering it to the session if it hasn't paused reading. If it has paused, data is buffered until reading is resumed. Data sent after the channel has been closed by the session is dropped. """ if not data: return if self._send_state in {'close_pending', 'closed'}: return if len(data) > self._recv_window: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Window exceeded') if self._recv_paused: self._recv_buf.append((data, datatype)) else: self._deliver_data(data, datatype) def _service_next_request(self): """Process next item on channel request queue""" request, packet, _ = self._request_queue[0] name = '_process_' + map_handler_name(request) + '_request' handler = getattr(self, name, None) result = handler(packet) if callable(handler) else False if result is not None: self._report_response(result) def _report_response(self, result): """Report back the response to a previously issued channel request""" request, _, want_reply = self._request_queue.pop(0) if want_reply and self._send_state not in {'close_pending', 'closed'}: if result: self._send_packet(MSG_CHANNEL_SUCCESS) else: self._send_packet(MSG_CHANNEL_FAILURE) if result and request in {'shell', 'exec', 'subsystem'}: self._session.session_started() self.resume_reading() if self._request_queue: self._service_next_request() def process_connection_close(self, exc): """Process the SSH connection closing""" if self._send_state != 'closed': self._send_state = 'closed' if self._recv_state not in {'close_pending', 'closed'}: self._recv_state = 'close_pending' self._flush_recv_buf(exc) elif self._recv_state == 'closed': self._loop.call_soon(self._cleanup, exc) def process_open(self, send_chan, send_window, send_pktsize, session): """Process a channel open request""" self._send_chan = send_chan self._send_window = send_window self._send_pktsize = send_pktsize self._conn.create_task(self._finish_open_request(session)) def _wrap_session(self, session): """Hook to optionally wrap channel and session objects""" # By default, return the original channel and session objects return self, session @asyncio.coroutine def _finish_open_request(self, session): """Finish processing a channel open request""" # pylint: disable=broad-except try: if asyncio.iscoroutine(session): session = yield from session chan, self._session = self._wrap_session(session) self._conn.send_channel_open_confirmation(self._send_chan, self._recv_chan, self._recv_window, self._recv_pktsize) self._send_state = 'open' self._recv_state = 'open' self._session.connection_made(chan) except ChannelOpenError as exc: self._conn.send_channel_open_failure(self._send_chan, exc.code, exc.reason, exc.lang) self._loop.call_soon(self._cleanup) def process_open_confirmation(self, send_chan, send_window, send_pktsize, packet): """Process a channel open confirmation""" if not self._open_waiter: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Channel not being opened') self._send_chan = send_chan self._send_window = send_window self._send_pktsize = send_pktsize self._send_state = 'open' self._recv_state = 'open' if not self._open_waiter.cancelled(): # pragma: no branch self._open_waiter.set_result(packet) self._open_waiter = None def process_open_failure(self, code, reason, lang): """Process a channel open failure""" if not self._open_waiter: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Channel not being opened') if not self._open_waiter.cancelled(): # pragma: no branch self._open_waiter.set_exception( ChannelOpenError(code, reason, lang)) self._open_waiter = None 
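# --- Example sketch added by the editor (not part of the original asyncssh source) ---
# _deliver_data() above buffers an incomplete trailing UTF-8 sequence in
# self._recv_partial rather than failing, because channel data may split a
# multi-byte character across packets.  A standalone demonstration of the
# decode error it relies on (plain Python 3, no dependencies):

encoded = 'é'.encode('utf-8')            # two bytes: 0xc3 0xa9
first, rest = encoded[:1], encoded[1:]

try:
    first.decode('utf-8')
except UnicodeDecodeError as exc:
    assert exc.reason == 'unexpected end of data'   # keep bytes for next packet

assert (first + rest).decode('utf-8') == 'é'
# --- end of example sketch ---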
self._loop.call_soon(self._cleanup) def _process_window_adjust(self, pkttype, packet): """Process a send window adjustment""" # pylint: disable=unused-argument if self._recv_state not in {'open', 'eof_pending', 'eof'}: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Channel not open') adjust = packet.get_uint32() packet.check_end() self._send_window += adjust self._flush_send_buf() def _process_data(self, pkttype, packet): """Process incoming data""" # pylint: disable=unused-argument if self._recv_state != 'open': raise DisconnectError(DISC_PROTOCOL_ERROR, 'Channel not open for sending') data = packet.get_string() packet.check_end() self._accept_data(data) def _process_extended_data(self, pkttype, packet): """Process incoming extended data""" # pylint: disable=unused-argument if self._recv_state != 'open': raise DisconnectError(DISC_PROTOCOL_ERROR, 'Channel not open for sending') datatype = packet.get_uint32() data = packet.get_string() packet.check_end() if datatype not in self._read_datatypes: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Invalid extended data type') self._accept_data(data, datatype) def _process_eof(self, pkttype, packet): """Process an incoming end of file""" # pylint: disable=unused-argument if self._recv_state != 'open': raise DisconnectError(DISC_PROTOCOL_ERROR, 'Channel not open for sending') packet.check_end() self._recv_state = 'eof_pending' self._flush_recv_buf() def _process_close(self, pkttype, packet): """Process an incoming channel close""" # pylint: disable=unused-argument if self._recv_state not in {'open', 'eof_pending', 'eof'}: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Channel not open') packet.check_end() self._close_send() self._recv_state = 'close_pending' self._flush_recv_buf() def _process_request(self, pkttype, packet): """Process an incoming channel request""" # pylint: disable=unused-argument if self._recv_state not in {'open', 'eof_pending', 'eof'}: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Channel not open') request = packet.get_string() want_reply = packet.get_boolean() try: request = request.decode('ascii') except UnicodeDecodeError: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Invalid channel request') from None self._request_queue.append((request, packet, want_reply)) if len(self._request_queue) == 1: self._service_next_request() def _process_response(self, pkttype, packet): """Process a success or failure response""" packet.check_end() if self._request_waiters: waiter = self._request_waiters.pop(0) if not waiter.cancelled(): # pragma: no branch waiter.set_result(pkttype == MSG_CHANNEL_SUCCESS) else: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Unexpected channel response') packet_handlers = { MSG_CHANNEL_WINDOW_ADJUST: _process_window_adjust, MSG_CHANNEL_DATA: _process_data, MSG_CHANNEL_EXTENDED_DATA: _process_extended_data, MSG_CHANNEL_EOF: _process_eof, MSG_CHANNEL_CLOSE: _process_close, MSG_CHANNEL_REQUEST: _process_request, MSG_CHANNEL_SUCCESS: _process_response, MSG_CHANNEL_FAILURE: _process_response } @asyncio.coroutine def _open(self, chantype, *args): """Make a request to open the channel""" if self._send_state != 'closed': raise OSError('Channel already open') self._open_waiter = asyncio.Future(loop=self._loop) self._conn.send_packet(Byte(MSG_CHANNEL_OPEN), String(chantype), UInt32(self._recv_chan), UInt32(self._recv_window), UInt32(self._recv_pktsize), *args) return (yield from self._open_waiter) def _send_packet(self, pkttype, *args): """Send a packet on the channel""" if self._send_chan is None: # pragma: no cover return 
self._conn.send_packet(Byte(pkttype), UInt32(self._send_chan), *args) def _send_request(self, request, *args, want_reply=False): """Send a channel request""" self._send_packet(MSG_CHANNEL_REQUEST, String(request), Boolean(want_reply), *args) @asyncio.coroutine def _make_request(self, request, *args): """Make a channel request and wait for the response""" if self._send_chan is None: return False waiter = asyncio.Future(loop=self._loop) self._request_waiters.append(waiter) self._send_request(request, *args, want_reply=True) return (yield from waiter) def abort(self): """Forcibly close the channel This method can be called to forcibly close the channel, after which no more data can be sent or received. Any unsent buffered data and any incoming data in flight will be discarded. """ if self._send_state not in {'close_pending', 'closed'}: # Send an immediate close, discarding unsent data self._close_send() if self._recv_state != 'closed': # Discard unreceived data self._discard_recv() def close(self): """Cleanly close the channel This method can be called to cleanly close the channel, after which no more data can be sent or received. Any unsent buffered data will be flushed asynchronously before the channel is closed. """ if self._send_state not in {'close_pending', 'closed'}: # Send a close only after sending unsent data self._send_state = 'close_pending' self._flush_send_buf() if self._recv_state != 'closed': # Discard unreceived data self._discard_recv() @asyncio.coroutine def wait_closed(self): """Wait for this channel to close This method is a coroutine which can be called to block until this channel has finished closing. """ yield from self._close_event.wait() def get_extra_info(self, name, default=None): """Get additional information about the channel This method returns extra information about the channel once it is established. Supported values include ``'connection'`` to return the SSH connection this channel is running over plus all of the values supported on that connection. For TCP channels, the values ``'local_peername'`` and ``'remote_peername'`` are added to return the local and remote host and port information for the tunneled TCP connection. For UNIX channels, the values ``'local_peername'`` and ``'remote_peername'`` are added to return the local and remote path information for the tunneled UNIX domain socket connection. Since UNIX domain sockets provide no "source" address, only one of these will be filled in. See :meth:`get_extra_info() ` on :class:`SSHClientConnection` for more information. """ return self._extra.get(name, self._conn.get_extra_info(name, default) if self._conn else default) def can_write_eof(self): """Return whether the channel supports :meth:`write_eof` This method always returns ``True``. """ # pylint: disable=no-self-use return True def get_write_buffer_size(self): """Return the current size of the channel's output buffer This method returns how many bytes are currently in the channel's output buffer waiting to be written. """ return self._send_buf_len def set_write_buffer_limits(self, high=None, low=None): """Set the high- and low-water limits for write flow control This method sets the limits used when deciding when to call the ``pause_writing()`` and ``resume_writing()`` methods on SSH sessions. Writing will be paused when the write buffer size exceeds the high-water mark, and resumed when the write buffer size equals or drops below the low-water mark. 
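# --- Example sketch added by the editor (not part of the original asyncssh source) ---
# The write flow control described above can be tuned per channel.  A minimal
# client session sketch that adjusts the limits and reacts to back pressure;
# the limit values are arbitrary examples:

import asyncssh

class ThrottledSession(asyncssh.SSHClientSession):
    def connection_made(self, chan):
        self._chan = chan
        self._paused = False
        # Pause when more than 32 KiB is buffered, resume at 8 KiB or below
        chan.set_write_buffer_limits(high=32768, low=8192)

    def pause_writing(self):
        self._paused = True              # stop generating output for now

    def resume_writing(self):
        self._paused = False             # safe to call self._chan.write() again
# --- end of example sketch ---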
""" if high is None: high = 4*low if low is not None else 65536 if low is None: low = high // 4 if not 0 <= low <= high: raise ValueError('high (%r) must be >= low (%r) must be >= 0' % (high, low)) self._send_high_water = high self._send_low_water = low self._pause_resume_writing() def write(self, data, datatype=None): """Write data on the channel This method can be called to send data on the channel. If an encoding was specified when the channel was created, the data should be provided as a string and will be converted using that encoding. Otherwise, the data should be provided as bytes. An extended data type can optionally be provided. For instance, this is used from a :class:`SSHServerSession` to write data to ``stderr``. :param data: The data to send on the channel :param int datatype: (optional) The extended data type of the data, from :ref:`extended data types ` :type data: str or bytes :raises: :exc:`OSError` if the channel isn't open for sending or the extended data type is not valid for this type of channel """ if self._send_state != 'open': raise BrokenPipeError('Channel not open for sending') if datatype is not None and datatype not in self._write_datatypes: raise OSError('Invalid extended data type') if len(data) == 0: return if self._encoding: data = data.encode(self._encoding) self._send_buf.append((bytearray(data), datatype)) self._send_buf_len += len(data) self._flush_send_buf() def writelines(self, list_of_data, datatype=None): """Write a list of data bytes on the channel This method can be called to write a list (or any iterable) of data bytes to the channel. It is functionality equivalent to calling :meth:`write` on each element in the list. :param list_of_data: The data to send on the channel :param int datatype: (optional) The extended data type of the data, from :ref:`extended data types ` :type list_of_data: iterable of str or bytes objects :raises: :exc:`OSError` if the channel isn't open for sending or the extended data type is not valid for this type of channel """ sep = '' if self._encoding else b'' return self.write(sep.join(list_of_data), datatype) def write_eof(self): """Write EOF on the channel This method sends an end-of-file indication on the channel, after which no more data can be sent. The channel remains open, though, and data may still be sent in the other direction. :raises: :exc:`OSError` if the channel isn't open for sending """ if self._send_state == 'open': self._send_state = 'eof_pending' self._flush_send_buf() def pause_reading(self): """Pause delivery of incoming data This method is used to temporarily suspend delivery of incoming channel data. After this call, incoming data will no longer be delivered until :meth:`resume_reading` is called. Data will be buffered locally up to the configured SSH channel window size, but window updates will no longer be sent, eventually causing back pressure on the remote system. .. note:: Channel close notifications are not suspended by this call. If the remote system closes the channel while delivery is suspended, the channel will be closed even though some buffered data may not have been delivered. """ self._recv_paused = True def resume_reading(self): """Resume delivery of incoming data This method can be called to resume delivery of incoming data which was suspended by a call to :meth:`pause_reading`. As soon as this method is called, any buffered data will be delivered immediately. A pending end-of-file notication may also be delivered if one was queued while reading was paused. 
""" if self._recv_paused: self._recv_paused = False self._flush_recv_buf() def get_environment(self): """Return the environment for this session This method returns the environment set by the client when the session was opened. On the server, calls to this method should only be made after :meth:`session_started ` has been called on the :class:`SSHServerSession`. When using the stream-based API, calls to this can be made at any time after the handler function has started up. :returns: A dictionary containing the environment variables set by the client """ return self._env def get_command(self): """Return the command the client requested to execute, if any This method returns the command the client requested to execute when the session was opened, if any. If the client did not request that a command be executed, this method will return ``None``. On the server, alls to this method should only be made after :meth:`session_started ` has been called on the :class:`SSHServerSession`. When using the stream-based API, calls to this can be made at any time after the handler function has started up. """ return self._command def get_subsystem(self): """Return the subsystem the client requested to open, if any This method returns the subsystem the client requested to open when the session was opened, if any. If the client did not request that a subsystem be opened, this method will return ``None``. On the server, calls to this method should only be made after :meth:`session_started ` has been called on the :class:`SSHServerSession`. When using the stream-based API, calls to this can be made at any time after the handler function has started up. """ return self._subsystem class SSHClientChannel(SSHChannel): """SSH client channel""" _read_datatypes = {EXTENDED_DATA_STDERR} def __init__(self, conn, loop, encoding, window, max_pktsize): super().__init__(conn, loop, encoding, window, max_pktsize) self._exit_status = None self._exit_signal = None def _cleanup(self, exc=None): """Clean up this channel""" if self._conn: # pragma: no branch self._conn.detach_x11_listener(self) super()._cleanup(exc) @asyncio.coroutine def create(self, session_factory, command, subsystem, env, term_type, term_size, term_modes, x11_forwarding, x11_display, x11_auth_path, x11_single_connection, agent_forwarding): """Create an SSH client session""" packet = yield from self._open(b'session') # Client sessions should have no extra data in the open confirmation packet.check_end() self._session = session_factory() self._session.connection_made(self) self._env = env self._command = command self._subsystem = subsystem for name, value in env.items(): self._send_request(b'env', String(str(name)), String(str(value))) if term_type: if not term_size: width = height = pixwidth = pixheight = 0 elif len(term_size) == 2: width, height = term_size pixwidth = pixheight = 0 elif len(term_size) == 4: width, height, pixwidth, pixheight = term_size else: raise ValueError('If set, terminal size must be a tuple of ' '2 or 4 integers') modes = b'' for mode, value in term_modes.items(): if mode <= PTY_OP_END or mode >= PTY_OP_RESERVED: raise ValueError('Invalid pty mode: %s' % mode) modes += Byte(mode) + UInt32(value) modes += Byte(PTY_OP_END) if not (yield from self._make_request(b'pty-req', String(term_type), UInt32(width), UInt32(height), UInt32(pixwidth), UInt32(pixheight), String(modes))): self.close() raise ChannelOpenError(OPEN_REQUEST_PTY_FAILED, 'PTY request failed') if x11_forwarding: try: auth_proto, remote_auth, screen = \ yield from 
self._conn.attach_x11_listener( self, x11_display, x11_auth_path, x11_single_connection) except ValueError as exc: raise ChannelOpenError(OPEN_REQUEST_X11_FORWARDING_FAILED, str(exc)) from None result = yield from self._make_request( b'x11-req', Boolean(x11_single_connection), String(auth_proto), String(binascii.b2a_hex(remote_auth)), UInt32(screen)) if not result: self._conn.detach_x11_listener(self) raise ChannelOpenError(OPEN_REQUEST_X11_FORWARDING_FAILED, 'X11 forwarding request failed') if agent_forwarding: self._send_request(b'auth-agent-req@openssh.com') if command: result = yield from self._make_request(b'exec', String(command)) elif subsystem: result = yield from self._make_request(b'subsystem', String(subsystem)) else: result = yield from self._make_request(b'shell') if not result: self.close() raise ChannelOpenError(OPEN_REQUEST_SESSION_FAILED, 'Session request failed') self._session.session_started() self.resume_reading() return self, self._session def _process_xon_xoff_request(self, packet): """Process a request to set up XON/XOFF processing""" client_can_do = packet.get_boolean() packet.check_end() self._session.xon_xoff_requested(client_can_do) return True def _process_exit_status_request(self, packet): """Process a request to deliver exit status""" status = packet.get_uint32() & 0xff packet.check_end() self._exit_status = status self._session.exit_status_received(status) return True def _process_exit_signal_request(self, packet): """Process a request to deliver an exit signal""" signal = packet.get_string() core_dumped = packet.get_boolean() msg = packet.get_string() lang = packet.get_string() packet.check_end() try: signal = signal.decode('ascii') msg = msg.decode('utf-8') lang = lang.decode('ascii') except UnicodeDecodeError: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Invalid exit signal request') from None self._exit_signal = (signal, core_dumped, msg, lang) self._session.exit_signal_received(signal, core_dumped, msg, lang) return True def get_exit_status(self): """Return the session's exit status This method returns the exit status of the session if one has been sent. If an exit signal was received, this method returns -1 and the exit signal information can be collected by calling :meth:`get_exit_signal`. If neither has been sent, this method returns ``None``. """ if self._exit_status is not None: return self._exit_status elif self._exit_signal: return -1 else: return None def get_exit_signal(self): """Return the session's exit signal, if one was sent This method returns information about the exit signal sent on this session. If an exit signal was sent, a tuple is returned containing the signal name, a boolean for whether a core dump occurred, a message associated with the signal, and the language the message was in. If no exit signal was sent, ``None`` is returned. """ return self._exit_signal def change_terminal_size(self, width, height, pixwidth=0, pixheight=0): """Change the terminal window size for this session This method changes the width and height of the terminal associated with this session. 
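For example, to report a 132 column by 48 row terminal with no pixel dimensions, something like the following could be used, where ``chan`` is assumed to be an open client session channel with a pseudo-terminal::

            chan.change_terminal_size(132, 48)
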
:param int width: The width of the terminal in characters :param int height: The height of the terminal in characters :param int pixwidth: (optional) The width of the terminal in pixels :param int pixheight: (optional) The height of the terminal in pixels """ self._send_request(b'window-change', UInt32(width), UInt32(height), UInt32(pixwidth), UInt32(pixheight)) def send_break(self, msec): """Send a break to the remote process This method requests that the server perform a break operation on the remote process or service as described in :rfc:`4335`. :param int msec: The duration of the break in milliseconds :raises: :exc:`OSError` if the channel is not open """ self._send_request(b'break', UInt32(msec)) def send_signal(self, signal): """Send a signal to the remote process This method can be called to deliver a signal to the remote process or service. Signal names should be as described in section 6.10 of :rfc:`RFC 4254 <4254#section-6.10>`. .. note:: OpenSSH's SSH server implementation does not currently support this message, so attempts to use :meth:`send_signal`, :meth:`terminate`, or :meth:`kill` with an OpenSSH SSH server will end up being ignored. This is currently being tracked in OpenSSH `bug 1424`__. __ https://bugzilla.mindrot.org/show_bug.cgi?id=1424 :param str signal: The signal to deliver :raises: :exc:`OSError` if the channel is not open """ self._send_request(b'signal', String(signal)) def terminate(self): """Terminate the remote process This method can be called to terminate the remote process or service by sending it a ``TERM`` signal. :raises: :exc:`OSError` if the channel is not open """ self.send_signal('TERM') def kill(self): """Forcibly kill the remote process This method can be called to forcibly stop the remote process or service by sending it a ``KILL`` signal. 
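For example, a client might escalate from a polite interrupt to a forced kill if the remote process does not exit. A sketch, where ``chan`` is assumed to be an open :class:`SSHClientChannel`::

            chan.send_signal('INT')    # ask the remote process to interrupt

            # ... if the channel still hasn't closed after a while ...
            chan.kill()                # force it to stop
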
:raises: :exc:`OSError` if the channel is not open """ self.send_signal('KILL') class SSHServerChannel(SSHChannel): """SSH server channel""" _write_datatypes = {EXTENDED_DATA_STDERR} def __init__(self, conn, loop, allow_pty, line_editor, line_history, encoding, window, max_pktsize): """Initialize an SSH server channel""" super().__init__(conn, loop, encoding, window, max_pktsize) self._env = conn.get_key_option('environment', {}) self._allow_pty = allow_pty self._line_editor = line_editor self._line_history = line_history self._term_type = None self._term_size = (0, 0, 0, 0) self._term_modes = {} self._x11_display = None def _cleanup(self, exc=None): """Clean up this channel""" self._conn.detach_x11_listener(self) super()._cleanup(exc) def _wrap_session(self, session): """Wrap a line editor around the session if enabled""" if self._line_editor: chan = SSHLineEditorChannel(self, session, self._line_history) session = SSHLineEditorSession(chan, session) else: chan = self return chan, session def _process_pty_req_request(self, packet): """Process a request to open a pseudo-terminal""" term_type = packet.get_string() width = packet.get_uint32() height = packet.get_uint32() pixwidth = packet.get_uint32() pixheight = packet.get_uint32() modes = packet.get_string() packet.check_end() if not self._allow_pty or \ not self._conn.check_key_permission('pty') or \ not self._conn.check_certificate_permission('pty'): return False try: term_type = term_type.decode('ascii') except UnicodeDecodeError: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Invalid pty request') from None term_size = (width, height, pixwidth, pixheight) term_modes = {} idx = 0 while idx < len(modes): mode = modes[idx] idx += 1 if mode == PTY_OP_END or mode >= PTY_OP_RESERVED: break if idx+4 <= len(modes): term_modes[mode] = int.from_bytes(modes[idx:idx+4], 'big') idx += 4 else: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Invalid pty modes string') result = self._session.pty_requested(self._term_type, self._term_size, self._term_modes) if result: self._term_type = term_type self._term_size = term_size self._term_modes = term_modes return result def _process_x11_req_request(self, packet): """Process request to enable X11 forwarding""" _ = packet.get_boolean() # single_connection auth_proto = packet.get_string() auth_data = packet.get_string() screen = packet.get_uint32() packet.check_end() try: auth_data = binascii.a2b_hex(auth_data) except binascii.Error: return False self._conn.create_task(self._finish_x11_req_request(auth_proto, auth_data, screen)) @asyncio.coroutine def _finish_x11_req_request(self, auth_proto, auth_data, screen): """Finish processing request to enable X11 forwarding""" self._x11_display = yield from self._conn.attach_x11_listener( self, auth_proto, auth_data, screen) self._report_response(bool(self._x11_display)) def _process_auth_agent_req_at_openssh_dot_com_request(self, packet): """Process a request to enable ssh-agent forwarding""" packet.check_end() self._conn.create_task(self._finish_agent_req_request()) @asyncio.coroutine def _finish_agent_req_request(self): """Finish processing request to enable agent forwarding""" self._report_response((yield from self._conn.create_agent_listener())) def _process_env_request(self, packet): """Process a request to set an environment variable""" name = packet.get_string() value = packet.get_string() packet.check_end() try: name = name.decode('utf-8') value = value.decode('utf-8') except UnicodeDecodeError: return False self._env[name] = value return True def 
_start_session(self, command=None, subsystem=None): """Tell the session what type of channel is being requested""" forced_command = self._conn.get_certificate_option('force-command') if forced_command is None: forced_command = self._conn.get_key_option('command') if forced_command is not None: command = forced_command if command is not None: self._command = command result = self._session.exec_requested(command) elif subsystem is not None: self._subsystem = subsystem result = self._session.subsystem_requested(subsystem) else: result = self._session.shell_requested() return result def _process_shell_request(self, packet): """Process a request to open a shell""" packet.check_end() return self._start_session() def _process_exec_request(self, packet): """Process a request to execute a command""" command = packet.get_string() packet.check_end() try: command = command.decode('utf-8') except UnicodeDecodeError: return False return self._start_session(command=command) def _process_subsystem_request(self, packet): """Process a request to open a subsystem""" subsystem = packet.get_string() packet.check_end() try: subsystem = subsystem.decode('ascii') except UnicodeDecodeError: return False return self._start_session(subsystem=subsystem) def _process_window_change_request(self, packet): """Process a request to change the window size""" width = packet.get_uint32() height = packet.get_uint32() pixwidth = packet.get_uint32() pixheight = packet.get_uint32() packet.check_end() self._term_size = (width, height, pixwidth, pixheight) self._session.terminal_size_changed(width, height, pixwidth, pixheight) return True def _process_signal_request(self, packet): """Process a request to send a signal""" signal = packet.get_string() packet.check_end() try: signal = signal.decode('ascii') except UnicodeDecodeError: return False self._session.signal_received(signal) return True def _process_break_request(self, packet): """Process a request to send a break""" msec = packet.get_uint32() packet.check_end() return self._session.break_received(msec) def get_terminal_type(self): """Return the terminal type for this session This method returns the terminal type set by the client when the session was opened. If the client didn't request a pseudo-terminal, this method will return ``None``. Calls to this method should only be made after :meth:`session_started ` has been called on the :class:`SSHServerSession`. When using the stream-based API, calls to this can be made at any time after the handler function has started up. :returns: A str containing the terminal type or ``None`` if no pseudo-terminal was requested """ return self._term_type def get_terminal_size(self): """Return terminal size information for this session This method returns the latest terminal size information set by the client. If the client didn't set any terminal size information, all values returned will be zero. Calls to this method should only be made after :meth:`session_started ` has been called on the :class:`SSHServerSession`. When using the stream-based API, calls to this can be made at any time after the handler function has started up. Also see :meth:`terminal_size_changed() ` or the :exc:`TerminalSizeChanged` exception for how to get notified when the terminal size changes. 
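For example, a server session might use this to format output to fit the client's terminal. A sketch, where ``chan`` is assumed to be this :class:`SSHServerChannel`::

            width, height, pixwidth, pixheight = chan.get_terminal_size()

            if width:
                banner = 'Welcome!'.center(width)
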
:returns: A tuple of four integers containing the width and height of the terminal in characters and the width and height of the terminal in pixels """ return self._term_size def get_terminal_mode(self, mode): """Return the requested TTY mode for this session This method looks up the value of a POSIX terminal mode set by the client when the session was opened. If the client didn't request a pseudo-terminal or didn't set the requested TTY mode opcode, this method will return ``None``. Calls to this method should only be made after :meth:`session_started ` has been called on the :class:`SSHServerSession`. When using the stream-based API, calls to this can be made at any time after the handler function has started up. :param int mode: POSIX terminal mode taken from :ref:`POSIX terminal modes ` to look up :returns: An int containing the value of the requested POSIX terminal mode or ``None`` if the requested mode was not set """ return self._term_modes.get(mode) def get_x11_display(self): """Return the display to use for X11 forwarding When X11 forwarding has been requested by the client, this method returns the X11 display which should be used to open a forwarded connection. If the client did not request X11 forwarding, this method returns ``None``. :returns: A str containing the X11 display or ``None`` if X11 fowarding was not requested """ return self._x11_display def get_agent_path(self): """Return the path of the ssh-agent listening socket When agent forwarding has been requested by the client, this method returns the path of the listening socket which should be used to open a forwarded agent connection. If the client did not request agent forwarding, this method returns ``None``. :returns: A str containing the ssh-agent socket path or ``None`` if agent fowarding was not requested """ return self._conn.get_agent_path() def set_xon_xoff(self, client_can_do): """Set whether the client should enable XON/XOFF flow control This method can be called to tell the client whether or not to enable XON/XOFF flow control, indicating that it should intercept Control-S and Control-Q coming from its local terminal to pause and resume output, respectively. Applications should set client_can_do to ``True`` to enable this functionality or to ``False`` to tell the client to forward Control-S and Control-Q through as normal input. :param bool client_can_do: Whether or not the client should enable XON/XOFF flow control """ self._send_request(b'xon-xoff', Boolean(client_can_do)) def write_stderr(self, data): """Write output to stderr This method can be called to send output to the client which is intended to be displayed on stderr. If an encoding was specified when the channel was created, the data should be provided as a string and will be converted using that encoding. Otherwise, the data should be provided as bytes. :param data: The data to send to stderr :type data: str or bytes :raises: :exc:`OSError` if the channel isn't open for sending """ self.write(data, EXTENDED_DATA_STDERR) def writelines_stderr(self, list_of_data): """Write a list of data bytes to stderr This method can be called to write a list (or any iterable) of data bytes to the channel. It is functionality equivalent to calling :meth:`write_stderr` on each element in the list. """ self.writelines(list_of_data, EXTENDED_DATA_STDERR) def exit(self, status): """Send exit status and close the channel This method can be called to report an exit status for the process back to the client and close the channel. 
A zero exit status is generally returned when the operation was successful. After reporting the status, the channel is closed. :param int status: The exit status to report to the client :raises: :exc:`OSError` if the channel isn't open """ if self._send_state not in {'close_pending', 'closed'}: self._send_request(b'exit-status', UInt32(status & 0xff)) self.close() def exit_with_signal(self, signal, core_dumped=False, msg='', lang=DEFAULT_LANG): """Send exit signal and close the channel This method can be called to report that the process terminated abnormslly with a signal. A more detailed error message may also provided, along with an indication of whether or not the process dumped core. After reporting the signal, the channel is closed. :param str signal: The signal which caused the process to exit :param bool core_dumped: (optional) Whether or not the process dumped core :param str msg: (optional) Details about what error occurred :param str lang: (optional) The language the error message is in :raises: :exc:`OSError` if the channel isn't open """ if self._send_state not in {'close_pending', 'closed'}: self._send_request(b'exit-signal', String(signal), Boolean(core_dumped), String(msg), String(lang)) self.close() class SSHForwardChannel(SSHChannel): """SSH channel for forwarding TCP and UNIX domain connections""" @asyncio.coroutine def _finish_open_request(self, session): """Finish processing a forward channel open request""" yield from super()._finish_open_request(session) if self._session: self._session.session_started() self.resume_reading() @asyncio.coroutine def _open(self, session_factory, chantype, *args): """Open a forward channel""" packet = yield from super()._open(chantype, *args) # Forward channels should have no extra data in the open confirmation packet.check_end() self._session = session_factory() self._session.connection_made(self) self._session.session_started() self.resume_reading() return self._session class SSHTCPChannel(SSHForwardChannel): """SSH TCP channel""" @asyncio.coroutine def _open_tcp(self, session_factory, chantype, host, port, orig_host, orig_port): """Open a TCP channel""" self._extra['peername'] = None self._extra['local_peername'] = (orig_host, orig_port) self._extra['remote_peername'] = (host, port) return (yield from super()._open(session_factory, chantype, String(host), UInt32(port), String(orig_host), UInt32(orig_port))) @asyncio.coroutine def connect(self, session_factory, host, port, orig_host, orig_port): """Create a new outbound TCP session""" return (yield from self._open_tcp(session_factory, b'direct-tcpip', host, port, orig_host, orig_port)) @asyncio.coroutine def accept(self, session_factory, host, port, orig_host, orig_port): """Create a new forwarded TCP session""" return (yield from self._open_tcp(session_factory, b'forwarded-tcpip', host, port, orig_host, orig_port)) def set_inbound_peer_names(self, dest_host, dest_port, orig_host, orig_port): """Set local and remote peer names for inbound connections""" self._extra['local_peername'] = (dest_host, dest_port) self._extra['remote_peername'] = (orig_host, orig_port) class SSHUNIXChannel(SSHForwardChannel): """SSH UNIX channel""" @asyncio.coroutine def _open_unix(self, session_factory, chantype, path, *args): """Open a UNIX channel""" self._extra['local_peername'] = '' self._extra['remote_peername'] = path return (yield from super()._open(session_factory, chantype, String(path), *args)) @asyncio.coroutine def connect(self, session_factory, path): """Create a new outbound UNIX session""" # 
OpenSSH appears to have a bug which requires an originator # host and port to be sent after the path name to connect to # when opening a direct streamlocal channel. return (yield from self._open_unix(session_factory, b'direct-streamlocal@openssh.com', path, String(''), UInt32(0))) @asyncio.coroutine def accept(self, session_factory, path): """Create a new forwarded UNIX session""" return (yield from self._open_unix(session_factory, b'forwarded-streamlocal@openssh.com', path, String(''))) def set_inbound_peer_names(self, dest_path): """Set local and remote peer names for inbound connections""" self._extra['local_peername'] = dest_path self._extra['remote_peername'] = '' class SSHX11Channel(SSHForwardChannel): """SSH X11 channel""" @asyncio.coroutine def open(self, session_factory, orig_host, orig_port): """Open an SSH X11 channel""" self._extra['local_peername'] = (orig_host, orig_port) self._extra['remote_peername'] = None return (yield from self._open(session_factory, b'x11', String(orig_host), UInt32(orig_port))) def set_inbound_peer_names(self, orig_host, orig_port): """Set local and remote peer name for inbound connections""" self._extra['local_peername'] = None self._extra['remote_peername'] = (orig_host, orig_port) class SSHAgentChannel(SSHForwardChannel): """SSH agent channel""" @asyncio.coroutine def open(self, session_factory): """Open an SSH agent channel""" return (yield from self._open(session_factory, b'auth-agent@openssh.com')) asyncssh-1.11.1/asyncssh/cipher.py000066400000000000000000000054621320320510200171000ustar00rootroot00000000000000# Copyright (c) 2013-2015 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Symmetric key encryption handlers""" from .crypto import lookup_cipher _enc_algs = [] _enc_params = {} _enc_ciphers = {} def register_encryption_alg(alg, cipher_name, mode_name, key_size, initial_bytes): """Register an encryption algorithm""" cipher = lookup_cipher(cipher_name, mode_name) if cipher: # pragma: no branch _enc_algs.append(alg) _enc_params[alg] = (key_size, cipher.iv_size, cipher.block_size, cipher.mode_name) _enc_ciphers[alg] = (cipher, initial_bytes) def get_encryption_algs(): """Return a list of available encryption algorithms""" return _enc_algs def get_encryption_params(alg): """Get parameters of an encryption algorithm This function returns the key, iv, and block sizes of an encryption algorithm. """ return _enc_params[alg] def get_cipher(alg, key, iv=None): """Return an instance of a cipher This function returns a cipher object initialized with the specified key and iv that can be used for data encryption and decryption. 
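For example, the following sketch encrypts a single block using a random key and iv. Note that these helpers are internal to AsyncSSH rather than part of its public API, and this is not how AsyncSSH itself derives keys during key exchange::

            import os

            params = get_encryption_params(b'aes256-ctr')
            key_size, iv_size, block_size, mode = params

            cipher = get_cipher(b'aes256-ctr', os.urandom(key_size),
                                os.urandom(iv_size))

            ciphertext = cipher.encrypt(b'sixteen byte msg')
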
""" cipher, initial_bytes = _enc_ciphers[alg] return cipher.new(key, iv, initial_bytes) # pylint: disable=bad-whitespace register_encryption_alg(b'chacha20-poly1305@openssh.com', 'chacha20-poly1305', 'chacha', 64, 0) register_encryption_alg(b'aes256-ctr', 'aes', 'ctr', 32, 0) register_encryption_alg(b'aes192-ctr', 'aes', 'ctr', 24, 0) register_encryption_alg(b'aes128-ctr', 'aes', 'ctr', 16, 0) register_encryption_alg(b'aes256-gcm@openssh.com', 'aes', 'gcm', 32, 0) register_encryption_alg(b'aes128-gcm@openssh.com', 'aes', 'gcm', 16, 0) register_encryption_alg(b'aes256-cbc', 'aes', 'cbc', 32, 0) register_encryption_alg(b'aes192-cbc', 'aes', 'cbc', 24, 0) register_encryption_alg(b'aes128-cbc', 'aes', 'cbc', 16, 0) register_encryption_alg(b'3des-cbc', 'des3', 'cbc', 24, 0) register_encryption_alg(b'blowfish-cbc', 'blowfish', 'cbc', 16, 0) register_encryption_alg(b'cast128-cbc', 'cast', 'cbc', 16, 0) register_encryption_alg(b'arcfour256', 'arc4', None, 32, 1536) register_encryption_alg(b'arcfour128', 'arc4', None, 16, 1536) register_encryption_alg(b'arcfour', 'arc4', None, 16, 0) asyncssh-1.11.1/asyncssh/client.py000066400000000000000000000275551320320510200171130ustar00rootroot00000000000000# Copyright (c) 2013-2017 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH client protocol handler""" class SSHClient: """SSH client protocol handler Applications should subclass this when implementing an SSH client. The functions listed below should be overridden to define application-specific behavior. In particular, the method :meth:`auth_completed` should be defined to open the desired SSH channels on this connection once authentication has been completed. For simple password or public key based authentication, nothing needs to be defined here if the password or client keys are passed in when the connection is created. However, to prompt interactively or otherwise dynamically select these values, the methods :meth:`password_auth_requested` and/or :meth:`public_key_auth_requested` can be defined. Keyboard-interactive authentication is also supported via :meth:`kbdint_auth_requested` and :meth:`kbdint_challenge_received`. If the server sends an authentication banner, the method :meth:`auth_banner_received` will be called. If the server requires a password change, the method :meth:`password_change_requested` will be called, followed by either :meth:`password_changed` or :meth:`password_change_failed` depending on whether the password change is successful. .. note:: The authentication callbacks described here can be defined as coroutines. However, they may be cancelled if they are running when the SSH connection is closed by the server. If they attempt to catch the CancelledError exception to perform cleanup, they should make sure to re-raise it to allow AsyncSSH to finish its own cleanup. """ # pylint: disable=no-self-use,unused-argument def connection_made(self, connection): """Called when a connection is made This method is called as soon as the TCP connection completes. The connection parameter should be stored if needed for later use. 
:param connection: The connection which was successfully opened :type connection: :class:`SSHClientConnection` """ pass # pragma: no cover def connection_lost(self, exc): """Called when a connection is lost or closed This method is called when a connection is closed. If the connection is shut down cleanly, *exc* will be ``None``. Otherwise, it will be an exception explaining the reason for the disconnect. :param exc: The exception which caused the connection to close, or ``None`` if the connection closed cleanly :type exc: :class:`Exception` """ pass # pragma: no cover def debug_msg_received(self, msg, lang, always_display): """A debug message was received on this connection This method is called when the other end of the connection sends a debug message. Applications should implement this method if they wish to process these debug messages. :param str msg: The debug message sent :param str lang: The language the message is in :param bool always_display: Whether or not to display the message """ pass # pragma: no cover def auth_banner_received(self, msg, lang): """An incoming authentication banner was received This method is called when the server sends a banner to display during authentication. Applications should implement this method if they wish to do something with the banner. :param str msg: The message the server wanted to display :param str lang: The language the message is in """ pass # pragma: no cover def auth_completed(self): """Authentication was completed successfully This method is called when authentication has completed successfully. Applications may use this method to create whatever client sessions and direct TCP/IP or UNIX domain connections are needed and/or set up listeners for incoming TCP/IP or UNIX domain connections coming from the server. """ pass # pragma: no cover def public_key_auth_requested(self): """Public key authentication has been requested This method should return a private key corresponding to the user that authentication is being attempted for. This method may be called multiple times and can return a different key to try each time it is called. When there are no keys left to try, it should return ``None`` to indicate that some other authentication method should be tried. If client keys were provided when the connection was opened, they will be tried before this method is called. If blocking operations need to be performed to determine the key to authenticate with, this method may be defined as a coroutine. :returns: A key as described in :ref:`SpecifyingPrivateKeys` or ``None`` to move on to another authentication method """ return None # pragma: no cover def password_auth_requested(self): """Password authentication has been requested This method should return a string containing the password corresponding to the user that authentication is being attempted for. It may be called multiple times and can return a different password to try each time, but most servers have a limit on the number of attempts allowed. When there's no password left to try, this method should return ``None`` to indicate that some other authentication method should be tried. If a password was provided when the connection was opened, it will be tried before this method is called. If blocking operations need to be performed to determine the password to authenticate with, this method may be defined as a coroutine.
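For example, to prompt for the password interactively, a sketch using the standard :mod:`getpass` module might look like::

            import getpass

            class MyClient(asyncssh.SSHClient):
                def password_auth_requested(self):
                    return getpass.getpass('SSH password: ')
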
:returns: A string containing the password to authenticate with or ``None`` to move on to another authentication method """ return None # pragma: no cover def password_change_requested(self, prompt, lang): """A password change has been requested This method is called when password authentication was attempted and the user's password was expired on the server. To request a password change, this method should return a tuple of two strings containing the old and new passwords. Otherwise, it should return ``NotImplemented``. If blocking operations need to be performed to determine the passwords to authenticate with, this method may be defined as a coroutine. By default, this method returns ``NotImplemented``. :param str prompt: The prompt requesting that the user enter a new password :param str lang: The language that the prompt is in :returns: A tuple of two strings containing the old and new passwords or ``NotImplemented`` if password changes aren't supported """ return NotImplemented # pragma: no cover def password_changed(self): """The requested password change was successful This method is called to indicate that a requested password change was successful. It is generally followed by a call to :meth:`auth_completed` since this means authentication was also successful. """ pass # pragma: no cover def password_change_failed(self): """The requested password change has failed This method is called to indicate that a requested password change failed, generally because the requested new password doesn't meet the password criteria on the remote system. After this method is called, other forms of authentication will automatically be attempted. """ pass # pragma: no cover def kbdint_auth_requested(self): """Keyboard-interactive authentication has been requested This method should return a string containing a comma-separated list of submethods that the server should use for keyboard-interactive authentication. An empty string can be returned to let the server pick the type of keyboard-interactive authentication to perform. If keyboard-interactive authentication is not supported, ``None`` should be returned. By default, keyboard-interactive authentication is supported if a password was provided when the :class:`SSHClient` was created and it hasn't been sent yet. If the challenge is not a password challenge, this authentication will fail. This method and the :meth:`kbdint_challenge_received` method can be overridden if other forms of challenge should be supported. If blocking operations need to be performed to determine the submethods to request, this method may be defined as a coroutine. :returns: A string containing the submethods the server should use for authentication or ``None`` to move on to another authentication method """ return NotImplemented # pragma: no cover def kbdint_challenge_received(self, name, instruction, lang, prompts): """A keyboard-interactive auth challenge has been received This method is called when the server sends a keyboard-interactive authentication challenge. The return value should be a list of strings of the same length as the number of prompts provided if the challenge can be answered, or ``None`` to indicate that some other form of authentication should be attempted. If blocking operations need to be performed to determine the responses to authenticate with, this method may be defined as a coroutine.
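For example, a subclass might answer every prompt with a stored password. In this sketch ``self._password`` is a hypothetical attribute, and a real application should inspect the prompt text before answering::

            class MyClient(asyncssh.SSHClient):
                def kbdint_challenge_received(self, name, instruction,
                                              lang, prompts):
                    return [self._password for _ in prompts]
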
By default, this method will look for a challenge consisting of a single 'Password:' prompt, and call the method :meth:`password_auth_requested` to provide the response. It will also ignore challenges with no prompts (generally used to provide instructions). Any other form of challenge will cause this method to return ``None`` to move on to another authentication method. :param str name: The name of the challenge :param str instruction: Instructions to the user about how to respond to the challenge :param str lang: The language the challenge is in :param prompts: The challenges the user should respond to and whether or not the responses should be echoed when they are entered :type prompts: list of tuples of str and bool :returns: List of string responses to the challenge or ``None`` to move on to another authentication method """ return None # pragma: no cover asyncssh-1.11.1/asyncssh/compression.py000066400000000000000000000044121320320510200201610ustar00rootroot00000000000000# Copyright (c) 2013-2015 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH compression handlers""" import zlib _cmp_algs = [] _cmp_params = {} _cmp_compressors = {} _cmp_decompressors = {} def _none(): """Compressor/decompressor for no compression.""" return None class _ZLibCompress: """Wrapper class to force a sync flush when compressing""" def __init__(self): self._comp = zlib.compressobj() def compress(self, data): """Compress data using zlib compression with sync flush""" return self._comp.compress(data) + self._comp.flush(zlib.Z_SYNC_FLUSH) def register_compression_alg(alg, compressor, decompressor, after_auth): """Register a compression algorithm""" _cmp_algs.append(alg) _cmp_params[alg] = after_auth _cmp_compressors[alg] = compressor _cmp_decompressors[alg] = decompressor def get_compression_algs(): """Return a list of available compression algorithms""" return _cmp_algs def get_compression_params(alg): """Get parameters of a compression algorithm This function returns whether or not a compression algorithm should be delayed until after authentication completes. """ return _cmp_params[alg] def get_compressor(alg): """Return an instance of a compressor This function returns an object that can be used for data compression. """ return _cmp_compressors[alg]() def get_decompressor(alg): """Return an instance of a decompressor This function returns an object that can be used for data decompression. """ return _cmp_decompressors[alg]() # pylint: disable=bad-whitespace register_compression_alg(b'zlib@openssh.com', _ZLibCompress, zlib.decompressobj, True) register_compression_alg(b'zlib', _ZLibCompress, zlib.decompressobj, False) register_compression_alg(b'none', _none, _none, False) asyncssh-1.11.1/asyncssh/connection.py000066400000000000000000005667251320320510200200030ustar00rootroot00000000000000# Copyright (c) 2013-2017 by Ron Frederick . # All rights reserved. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH connection handlers""" import asyncio import getpass import io import os import socket import sys import time from collections import OrderedDict from .agent import connect_agent, create_agent_listener from .auth import lookup_client_auth from .auth import get_server_auth_methods, lookup_server_auth from .auth_keys import read_authorized_keys from .channel import SSHClientChannel, SSHServerChannel from .channel import SSHTCPChannel, SSHUNIXChannel from .channel import SSHX11Channel, SSHAgentChannel from .cipher import get_encryption_algs, get_encryption_params, get_cipher from .client import SSHClient from .compression import get_compression_algs, get_compression_params from .compression import get_compressor, get_decompressor from .constants import DEFAULT_LANG from .constants import DISC_BY_APPLICATION, DISC_CONNECTION_LOST from .constants import DISC_KEY_EXCHANGE_FAILED, DISC_HOST_KEY_NOT_VERIFYABLE from .constants import DISC_MAC_ERROR, DISC_NO_MORE_AUTH_METHODS_AVAILABLE from .constants import DISC_PROTOCOL_ERROR, DISC_SERVICE_NOT_AVAILABLE from .constants import EXTENDED_DATA_STDERR from .constants import MSG_DISCONNECT, MSG_IGNORE, MSG_UNIMPLEMENTED, MSG_DEBUG from .constants import MSG_SERVICE_REQUEST, MSG_SERVICE_ACCEPT, MSG_EXT_INFO from .constants import MSG_CHANNEL_OPEN, MSG_CHANNEL_OPEN_CONFIRMATION from .constants import MSG_CHANNEL_OPEN_FAILURE, MSG_CHANNEL_WINDOW_ADJUST from .constants import MSG_CHANNEL_DATA, MSG_CHANNEL_EXTENDED_DATA from .constants import MSG_CHANNEL_EOF, MSG_CHANNEL_CLOSE, MSG_CHANNEL_REQUEST from .constants import MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE from .constants import MSG_KEXINIT, MSG_NEWKEYS, MSG_KEX_FIRST, MSG_KEX_LAST from .constants import MSG_USERAUTH_REQUEST, MSG_USERAUTH_FAILURE from .constants import MSG_USERAUTH_SUCCESS, MSG_USERAUTH_BANNER from .constants import MSG_USERAUTH_FIRST, MSG_USERAUTH_LAST from .constants import MSG_GLOBAL_REQUEST, MSG_REQUEST_SUCCESS from .constants import MSG_REQUEST_FAILURE from .constants import OPEN_ADMINISTRATIVELY_PROHIBITED, OPEN_CONNECT_FAILED from .constants import OPEN_UNKNOWN_CHANNEL_TYPE from .forward import SSHForwarder from .gss import GSSClient, GSSServer, GSSError from .kex import get_kex_algs, expand_kex_algs, get_kex from .known_hosts import match_known_hosts from .listener import SSHTCPClientListener, create_tcp_forward_listener from .listener import SSHUNIXClientListener, create_unix_forward_listener from .logging import logger from .mac import get_mac_algs, get_mac_params, get_mac from .misc import ChannelOpenError, DisconnectError, PasswordChangeRequired from .misc import async_context_manager, create_task, ip_address from .misc import load_default_keypairs, map_handler_name from .packet import Boolean, Byte, NameList, String, UInt32, UInt64 from .packet import PacketDecodeError, SSHPacket, SSHPacketHandler from .process import PIPE, SSHClientProcess, SSHServerProcess from .public_key import CERT_TYPE_HOST, CERT_TYPE_USER, KeyImportError from .public_key import decode_ssh_public_key, decode_ssh_certificate from .public_key import get_public_key_algs, get_certificate_algs from .public_key import get_x509_certificate_algs from .public_key import load_keypairs, load_certificates 
from .saslprep import saslprep, SASLPrepError from .server import SSHServer from .sftp import SFTPServer, start_sftp_client from .stream import SSHClientStreamSession, SSHServerStreamSession from .stream import SSHTCPStreamSession, SSHUNIXStreamSession from .stream import SSHReader, SSHWriter from .x11 import create_x11_client_listener, create_x11_server_listener # SSH default port _DEFAULT_PORT = 22 # SSH service names _USERAUTH_SERVICE = b'ssh-userauth' _CONNECTION_SERVICE = b'ssh-connection' # Default rekey parameters _DEFAULT_REKEY_BYTES = 1 << 30 # 1 GiB _DEFAULT_REKEY_SECONDS = 3600 # 1 hour # Default login timeout _DEFAULT_LOGIN_TIMEOUT = 120 # 2 minutes # Default channel parameters _DEFAULT_WINDOW = 2*1024*1024 # 2 MiB _DEFAULT_MAX_PKTSIZE = 32768 # 32 kiB # Default line editor parameters _DEFAULT_LINE_HISTORY = 1000 # 1000 lines def _validate_version(version): """Validate requested SSH version""" if version is (): from .version import __version__ version = b'AsyncSSH_' + __version__.encode('ascii') else: if isinstance(version, str): version = version.encode('ascii') # Version including 'SSH-2.0-' and CRLF must be 255 chars or less if len(version) > 245: raise ValueError('Version string is too long') for b in version: if b < 0x20 or b > 0x7e: raise ValueError('Version string must be printable ASCII') return version def _select_algs(alg_type, algs, possible_algs, none_value=None): """Select a set of allowed algorithms""" if algs == (): return possible_algs elif algs: result = [] for alg_str in algs: alg = alg_str.encode('ascii') if alg not in possible_algs: raise ValueError('%s is not a valid %s algorithm' % (alg_str, alg_type)) result.append(alg) return result elif none_value: return [none_value] else: raise ValueError('No %s algorithms selected' % alg_type) def _validate_algs(kex_algs, enc_algs, mac_algs, cmp_algs, sig_algs, allow_x509): """Validate requested algorithms""" kex_algs = _select_algs('key exchange', kex_algs, get_kex_algs()) enc_algs = _select_algs('encryption', enc_algs, get_encryption_algs()) mac_algs = _select_algs('MAC', mac_algs, get_mac_algs()) cmp_algs = _select_algs('compression', cmp_algs, get_compression_algs(), b'none') allowed_sig_algs = get_x509_certificate_algs() if allow_x509 else [] allowed_sig_algs = allowed_sig_algs + get_public_key_algs() sig_algs = _select_algs('signature', sig_algs, allowed_sig_algs) return kex_algs, enc_algs, mac_algs, cmp_algs, sig_algs class SSHConnection(SSHPacketHandler): """Parent class for SSH connections""" def __init__(self, protocol_factory, loop, version, kex_algs, encryption_algs, mac_algs, compression_algs, signature_algs, rekey_bytes, rekey_seconds, server): self._protocol_factory = protocol_factory self._loop = loop self._transport = None self._peer_addr = None self._owner = None self._extra = {} self._server = server self._inpbuf = b'' self._packet = b'' self._pktlen = 0 self._version = version self._client_version = b'' self._server_version = b'' self._client_kexinit = b'' self._server_kexinit = b'' self._session_id = None self._send_seq = 0 self._send_cipher = None self._send_blocksize = 8 self._send_mac = None self._send_mode = None self._compressor = None self._compress_after_auth = False self._deferred_packets = [] self._recv_handler = self._recv_version self._recv_seq = 0 self._recv_cipher = None self._recv_blocksize = 8 self._recv_mac = None self._recv_macsize = 0 self._recv_mode = None self._decompressor = None self._decompress_after_auth = None self._next_recv_cipher = None self._next_recv_blocksize = 0 
self._next_recv_mac = None self._next_recv_macsize = 0 self._next_recv_mode = None self._next_decompressor = None self._next_decompress_after_auth = None self._kex_algs = kex_algs self._enc_algs = encryption_algs self._mac_algs = mac_algs self._cmp_algs = compression_algs self._sig_algs = signature_algs self._kex = None self._kexinit_sent = False self._kex_complete = False self._ignore_first_kex = False self._gss = None self._gss_kex_auth = False self._gss_mic_auth = False self._rekey_bytes = rekey_bytes self._rekey_bytes_sent = 0 self._rekey_seconds = rekey_seconds self._rekey_time = time.time() + rekey_seconds self._enc_alg_cs = None self._enc_alg_sc = None self._mac_alg_cs = None self._mac_alg_sc = None self._cmp_alg_cs = None self._cmp_alg_sc = None self._can_send_ext_info = False self._extensions_sent = OrderedDict() self._server_sig_algs = () self._next_service = None self._agent = None self._auth = None self._auth_in_progress = False self._auth_complete = False self._auth_methods = [b'none'] self._auth_waiter = None self._username = None self._channels = {} self._next_recv_chan = 0 self._global_request_queue = [] self._global_request_waiters = [] self._local_listeners = {} self._x11_listener = None self._close_event = asyncio.Event(loop=loop) self._server_host_key_algs = [] def __enter__(self): """Allow SSHConnection to be used as a context manager""" return self def __exit__(self, *exc_info): """Automatically close the connection when used as a context manager""" try: self.close() except RuntimeError as exc: # pragma: no cover # There's a race in some cases between the close call here # and the code which shuts down the event loop. Since the # loop.is_closed() method is only in Python 3.4.2 and later, # catch and ignore the RuntimeError for now if this happens. 
if exc.args[0] == 'Event loop is closed': pass else: raise @asyncio.coroutine def __aenter__(self): """Allow SSHConnection to be used as an async context manager""" return self @asyncio.coroutine def __aexit__(self, *exc_info): """Wait for connection close when used as an async context manager""" self.__exit__() yield from self.wait_closed() def _cleanup(self, exc): """Clean up this connection""" for chan in list(self._channels.values()): chan.process_connection_close(exc) for listener in self._local_listeners.values(): listener.close() while self._global_request_waiters: self._process_global_response(MSG_REQUEST_FAILURE, SSHPacket(b'')) if self._owner: # pragma: no branch self._owner.connection_lost(exc) self._owner = None self._close_event.set() self._inpbuf = b'' self._recv_handler = None def _force_close(self, exc): """Force this connection to close immediately""" if not self._transport: return self._transport.abort() self._transport = None self._loop.call_soon(self._cleanup, exc) def _reap_task(self, task): """Collect result of an async task, reporting errors""" # pylint: disable=broad-except try: task.result() except asyncio.CancelledError: pass except DisconnectError as exc: self._send_disconnect(exc.code, exc.reason, exc.lang) self._force_close(exc) except Exception: self.internal_error() def create_task(self, coro): """Create an asynchronous task which catches and reports errors""" task = create_task(coro, loop=self._loop) task.add_done_callback(self._reap_task) return task def is_client(self): """Return if this is a client connection""" return not self._server def is_server(self): """Return if this is a server connection""" return self._server def get_owner(self): """Return the SSHClient or SSHServer which owns this connection""" return self._owner def get_hash_prefix(self): """Return the bytes used in calculating unique connection hashes This methods returns a packetized version of the client and server version and kexinit strings which is needed to perform key exchange hashes. 
""" return b''.join((String(self._client_version), String(self._server_version), String(self._client_kexinit), String(self._server_kexinit))) def connection_made(self, transport): """Handle a newly opened connection""" self._transport = transport sock = transport.get_extra_info('socket') sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) peername = transport.get_extra_info('peername') self._peer_addr = peername[0] if peername else None self._owner = self._protocol_factory() self._protocol_factory = None # pylint: disable=broad-except try: self._connection_made() self._owner.connection_made(self) self._send_version() except DisconnectError as exc: self._loop.call_soon(self.connection_lost, exc) except Exception: self._loop.call_soon(self.internal_error, sys.exc_info()) def connection_lost(self, exc=None): """Handle the closing of a connection""" if exc is None and self._transport: exc = DisconnectError(DISC_CONNECTION_LOST, 'Connection lost') self._force_close(exc) def internal_error(self, exc_info=None): """Handle a fatal error in connection processing""" if not exc_info: exc_info = sys.exc_info() logger.debug('Uncaught exception', exc_info=exc_info) self._force_close(exc_info[1]) def session_started(self): """Handle session start when opening tunneled SSH connection""" pass def data_received(self, data, datatype=None): """Handle incoming data on the connection""" # pylint: disable=unused-argument self._inpbuf += data # pylint: disable=broad-except try: while self._inpbuf and self._recv_handler(): pass except DisconnectError as exc: self._send_disconnect(exc.code, exc.reason, exc.lang) self._force_close(exc) except Exception: self.internal_error() def eof_received(self): """Handle an incoming end of file on the connection""" self.connection_lost(None) def pause_writing(self): """Handle a request from the transport to pause writing data""" # Do nothing with this for now pass # pragma: no cover def resume_writing(self): """Handle a request from the transport to resume writing data""" # Do nothing with this for now pass # pragma: no cover def add_channel(self, chan): """Add a new channel, returning its channel number""" if not self._transport: raise ChannelOpenError(OPEN_CONNECT_FAILED, 'SSH connection closed') while self._next_recv_chan in self._channels: # pragma: no cover self._next_recv_chan = (self._next_recv_chan + 1) & 0xffffffff recv_chan = self._next_recv_chan self._next_recv_chan = (self._next_recv_chan + 1) & 0xffffffff self._channels[recv_chan] = chan return recv_chan def remove_channel(self, recv_chan): """Remove the channel with the specified channel number""" del self._channels[recv_chan] def get_gss_context(self): """Return the GSS context associated with this connection""" return self._gss def enable_gss_kex_auth(self): """Enable GSS key exchange authentication""" self._gss_kex_auth = True def _choose_alg(self, alg_type, local_algs, remote_algs): """Choose a common algorithm from the client & server lists This method returns the earliest algorithm on the client's list which is supported by the server. 
""" if self.is_client(): client_algs, server_algs = local_algs, remote_algs else: client_algs, server_algs = remote_algs, local_algs for alg in client_algs: if alg in server_algs: return alg raise DisconnectError(DISC_KEY_EXCHANGE_FAILED, 'No matching %s algorithm found' % alg_type) def _get_ext_info_kex_alg(self): """Return the kex alg to add if any to request extension info""" return [b'ext-info-c'] if self.is_client() else [] def _send(self, data): """Send data to the SSH connection""" if self._transport: self._transport.write(data) def _send_version(self): """Start the SSH handshake""" version = b'SSH-2.0-' + self._version if self.is_client(): self._client_version = version self._extra.update(client_version=version.decode('ascii')) else: self._server_version = version self._extra.update(server_version=version.decode('ascii')) self._send(version + b'\r\n') def _recv_version(self): """Receive and parse the remote SSH version""" idx = self._inpbuf.find(b'\n') if idx < 0: return False version = self._inpbuf[:idx] if version.endswith(b'\r'): version = version[:-1] self._inpbuf = self._inpbuf[idx+1:] if (version.startswith(b'SSH-2.0-') or (self.is_client() and version.startswith(b'SSH-1.99-'))): # Accept version 2.0, or 1.99 if we're a client if self.is_server(): self._client_version = version self._extra.update(client_version=version.decode('ascii')) else: self._server_version = version self._extra.update(server_version=version.decode('ascii')) self._send_kexinit() self._kexinit_sent = True self._recv_handler = self._recv_pkthdr elif self.is_client() and not version.startswith(b'SSH-'): # As a client, ignore the line if it doesn't appear to be a version pass else: # Otherwise, reject the unknown version self._force_close(DisconnectError(DISC_PROTOCOL_ERROR, 'Unknown SSH version')) return False return True def _recv_pkthdr(self): """Receive and parse an SSH packet header""" if len(self._inpbuf) < self._recv_blocksize: return False self._packet = self._inpbuf[:self._recv_blocksize] self._inpbuf = self._inpbuf[self._recv_blocksize:] pktlen = self._packet[:4] if self._recv_cipher: # pragma: no branch if self._recv_mode == 'chacha': nonce = UInt64(self._recv_seq) pktlen = self._recv_cipher.crypt_len(pktlen, nonce) elif self._recv_mode not in ('gcm', 'etm'): self._packet = self._recv_cipher.decrypt(self._packet) pktlen = self._packet[:4] self._pktlen = int.from_bytes(pktlen, 'big') self._recv_handler = self._recv_packet return True def _recv_packet(self): """Receive the remainder of an SSH packet and process it""" rem = 4 + self._pktlen + self._recv_macsize - self._recv_blocksize if len(self._inpbuf) < rem: return False rest = self._inpbuf[:rem-self._recv_macsize] if self._recv_mode in ('chacha', 'gcm'): packet = self._packet + rest mac = self._inpbuf[rem-self._recv_macsize:rem] hdr = packet[:4] packet = packet[4:] if self._recv_mode == 'chacha': nonce = UInt64(self._recv_seq) packet = self._recv_cipher.verify_and_decrypt(hdr, packet, nonce, mac) else: packet = self._recv_cipher.verify_and_decrypt(hdr, packet, mac) if not packet: raise DisconnectError(DISC_MAC_ERROR, 'MAC verification failed') elif self._recv_mode == 'etm': packet = self._packet + rest mac = self._inpbuf[rem-self._recv_macsize:rem] if not self._recv_mac.verify(self._recv_seq, packet, mac): raise DisconnectError(DISC_MAC_ERROR, 'MAC verification failed') packet = self._recv_cipher.decrypt(packet[4:]) else: if self._recv_cipher: rest = self._recv_cipher.decrypt(rest) packet = self._packet + rest mac = 
self._inpbuf[rem-self._recv_macsize:rem] if self._recv_mac: if not self._recv_mac.verify(self._recv_seq, packet, mac): raise DisconnectError(DISC_MAC_ERROR, 'MAC verification failed') packet = packet[4:] self._inpbuf = self._inpbuf[rem:] self._packet = b'' payload = packet[1:-packet[0]] if self._decompressor and (self._auth_complete or not self._decompress_after_auth): payload = self._decompressor.decompress(payload) try: packet = SSHPacket(payload) pkttype = packet.get_byte() if self._kex and MSG_KEX_FIRST <= pkttype <= MSG_KEX_LAST: if self._ignore_first_kex: # pragma: no cover self._ignore_first_kex = False processed = True else: processed = self._kex.process_packet(pkttype, packet) elif (self._auth and MSG_USERAUTH_FIRST <= pkttype <= MSG_USERAUTH_LAST): processed = self._auth.process_packet(pkttype, packet) else: processed = self.process_packet(pkttype, packet) except PacketDecodeError as exc: raise DisconnectError(DISC_PROTOCOL_ERROR, str(exc)) from None if not processed: self.send_packet(Byte(MSG_UNIMPLEMENTED), UInt32(self._recv_seq)) if self._transport: self._recv_seq = (self._recv_seq + 1) & 0xffffffff self._recv_handler = self._recv_pkthdr return True def send_packet(self, *args): """Send an SSH packet""" payload = b''.join(args) pkttype = payload[0] if (self._auth_complete and self._kex_complete and (self._rekey_bytes_sent >= self._rekey_bytes or time.monotonic() >= self._rekey_time)): self._send_kexinit() self._kexinit_sent = True if (((pkttype in {MSG_SERVICE_REQUEST, MSG_SERVICE_ACCEPT} or pkttype > MSG_KEX_LAST) and not self._kex_complete) or (pkttype == MSG_USERAUTH_BANNER and not (self._auth_in_progress or self._auth_complete)) or (pkttype > MSG_USERAUTH_LAST and not self._auth_complete)): self._deferred_packets.append(payload) return # If we're encrypting and we have no data outstanding, insert an # ignore packet into the stream if self._send_cipher and payload[0] != MSG_IGNORE: self.send_packet(Byte(MSG_IGNORE), String(b'')) if self._compressor and (self._auth_complete or not self._compress_after_auth): payload = self._compressor.compress(payload) hdrlen = 1 if self._send_mode in ('chacha', 'gcm', 'etm') else 5 padlen = -(hdrlen + len(payload)) % self._send_blocksize if padlen < 4: padlen += self._send_blocksize packet = Byte(padlen) + payload + os.urandom(padlen) pktlen = len(packet) hdr = UInt32(pktlen) if self._send_mode == 'chacha': nonce = UInt64(self._send_seq) hdr = self._send_cipher.crypt_len(hdr, nonce) packet, mac = self._send_cipher.encrypt_and_sign(hdr, packet, nonce) packet = hdr + packet elif self._send_mode == 'gcm': packet, mac = self._send_cipher.encrypt_and_sign(hdr, packet) packet = hdr + packet elif self._send_mode == 'etm': packet = hdr + self._send_cipher.encrypt(packet) mac = self._send_mac.sign(self._send_seq, packet) else: packet = hdr + packet if self._send_mac: mac = self._send_mac.sign(self._send_seq, packet) else: mac = b'' if self._send_cipher: packet = self._send_cipher.encrypt(packet) self._send(packet + mac) self._send_seq = (self._send_seq + 1) & 0xffffffff if self._kex_complete: self._rekey_bytes_sent += pktlen def _send_deferred_packets(self): """Send packets deferred due to key exchange or auth""" deferred_packets = self._deferred_packets self._deferred_packets = [] for packet in deferred_packets: self.send_packet(packet) def _send_disconnect(self, code, reason, lang): """Send a disconnect packet""" self.send_packet(Byte(MSG_DISCONNECT), UInt32(code), String(reason), String(lang)) def _send_kexinit(self): """Start a key exchange""" 
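        # For reference, the SSH_MSG_KEXINIT payload assembled below follows
        # the layout defined in RFC 4253, section 7.1:
        #
        #     byte        SSH_MSG_KEXINIT
        #     byte[16]    random cookie
        #     name-list   kex_algorithms
        #     name-list   server_host_key_algorithms
        #     name-list   encryption_algorithms (client to server)
        #     name-list   encryption_algorithms (server to client)
        #     name-list   mac_algorithms (client to server)
        #     name-list   mac_algorithms (server to client)
        #     name-list   compression_algorithms (client to server)
        #     name-list   compression_algorithms (server to client)
        #     name-list   languages (client to server)
        #     name-list   languages (server to client)
        #     boolean     first_kex_packet_follows
        #     uint32      0 (reserved)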
self._kex_complete = False self._rekey_bytes_sent = 0 self._rekey_time = time.monotonic() + self._rekey_seconds gss_mechs = self._gss.mechs if self._gss else [] kex_algs = expand_kex_algs(self._kex_algs, gss_mechs, bool(self._server_host_key_algs)) cookie = os.urandom(16) kex_algs = NameList(kex_algs + self._get_ext_info_kex_alg()) host_key_algs = NameList(self._server_host_key_algs or [b'null']) enc_algs = NameList(self._enc_algs) mac_algs = NameList(self._mac_algs) cmp_algs = NameList(self._cmp_algs) langs = NameList([]) packet = b''.join((Byte(MSG_KEXINIT), cookie, kex_algs, host_key_algs, enc_algs, enc_algs, mac_algs, mac_algs, cmp_algs, cmp_algs, langs, langs, Boolean(False), UInt32(0))) if self.is_server(): self._server_kexinit = packet else: self._client_kexinit = packet self.send_packet(packet) def _send_ext_info(self): """Send extension information""" packet = b''.join((Byte(MSG_EXT_INFO), UInt32(len(self._extensions_sent)))) for name, value in self._extensions_sent.items(): packet += String(name) + String(value) self.send_packet(packet) def send_newkeys(self, k, h): """Finish a key exchange and send a new keys message""" if not self._session_id: self._session_id = h enc_keysize_cs, enc_ivsize_cs, enc_blocksize_cs, mode_cs = \ get_encryption_params(self._enc_alg_cs) enc_keysize_sc, enc_ivsize_sc, enc_blocksize_sc, mode_sc = \ get_encryption_params(self._enc_alg_sc) if mode_cs in ('chacha', 'gcm'): mac_keysize_cs, mac_hashsize_cs = 0, 16 else: mac_keysize_cs, mac_hashsize_cs, etm_cs = \ get_mac_params(self._mac_alg_cs) if etm_cs: mode_cs = 'etm' if mode_sc in ('chacha', 'gcm'): mac_keysize_sc, mac_hashsize_sc = 0, 16 else: mac_keysize_sc, mac_hashsize_sc, etm_sc = \ get_mac_params(self._mac_alg_sc) if etm_sc: mode_sc = 'etm' cmp_after_auth_cs = get_compression_params(self._cmp_alg_cs) cmp_after_auth_sc = get_compression_params(self._cmp_alg_sc) iv_cs = self._kex.compute_key(k, h, b'A', self._session_id, enc_ivsize_cs) iv_sc = self._kex.compute_key(k, h, b'B', self._session_id, enc_ivsize_sc) enc_key_cs = self._kex.compute_key(k, h, b'C', self._session_id, enc_keysize_cs) enc_key_sc = self._kex.compute_key(k, h, b'D', self._session_id, enc_keysize_sc) mac_key_cs = self._kex.compute_key(k, h, b'E', self._session_id, mac_keysize_cs) mac_key_sc = self._kex.compute_key(k, h, b'F', self._session_id, mac_keysize_sc) self._kex = None next_cipher_cs = get_cipher(self._enc_alg_cs, enc_key_cs, iv_cs) next_cipher_sc = get_cipher(self._enc_alg_sc, enc_key_sc, iv_sc) if mode_cs in ('chacha', 'gcm'): self._mac_alg_cs = self._enc_alg_cs next_mac_cs = None else: next_mac_cs = get_mac(self._mac_alg_cs, mac_key_cs) if mode_sc in ('chacha', 'gcm'): self._mac_alg_sc = self._enc_alg_sc next_mac_sc = None else: next_mac_sc = get_mac(self._mac_alg_sc, mac_key_sc) self.send_packet(Byte(MSG_NEWKEYS)) if self.is_client(): self._send_cipher = next_cipher_cs self._send_blocksize = max(8, enc_blocksize_cs) self._send_mac = next_mac_cs self._send_mode = mode_cs self._compressor = get_compressor(self._cmp_alg_cs) self._compress_after_auth = cmp_after_auth_cs self._next_recv_cipher = next_cipher_sc self._next_recv_blocksize = max(8, enc_blocksize_sc) self._next_recv_mac = next_mac_sc self._next_recv_macsize = mac_hashsize_sc self._next_recv_mode = mode_sc self._next_decompressor = get_decompressor(self._cmp_alg_sc) self._next_decompress_after_auth = cmp_after_auth_sc self._extra.update( send_cipher=self._enc_alg_cs.decode('ascii'), send_mac=self._mac_alg_cs.decode('ascii'), 
send_compression=self._cmp_alg_cs.decode('ascii'), recv_cipher=self._enc_alg_sc.decode('ascii'), recv_mac=self._mac_alg_sc.decode('ascii'), recv_compression=self._cmp_alg_sc.decode('ascii')) else: self._send_cipher = next_cipher_sc self._send_blocksize = max(8, enc_blocksize_sc) self._send_mac = next_mac_sc self._send_mode = mode_sc self._compressor = get_compressor(self._cmp_alg_sc) self._compress_after_auth = cmp_after_auth_sc self._next_recv_cipher = next_cipher_cs self._next_recv_blocksize = max(8, enc_blocksize_cs) self._next_recv_mac = next_mac_cs self._next_recv_macsize = mac_hashsize_cs self._next_recv_mode = mode_cs self._next_decompressor = get_decompressor(self._cmp_alg_cs) self._next_decompress_after_auth = cmp_after_auth_cs self._extra.update( send_cipher=self._enc_alg_sc.decode('ascii'), send_mac=self._mac_alg_sc.decode('ascii'), send_compression=self._cmp_alg_sc.decode('ascii'), recv_cipher=self._enc_alg_cs.decode('ascii'), recv_mac=self._mac_alg_cs.decode('ascii'), recv_compression=self._cmp_alg_cs.decode('ascii')) self._next_service = _USERAUTH_SERVICE if self._can_send_ext_info: self._extensions_sent['server-sig-algs'] = \ b','.join(self._sig_algs) self._send_ext_info() self._kex_complete = True self._send_deferred_packets() def send_service_request(self, service): """Send a service request""" self._next_service = service self.send_packet(Byte(MSG_SERVICE_REQUEST), String(service)) def _get_userauth_request_packet(self, method, args): """Get packet data for a user authentication request""" return b''.join((Byte(MSG_USERAUTH_REQUEST), String(self._username), String(_CONNECTION_SERVICE), String(method)) + args) def get_userauth_request_data(self, method, *args): """Get signature data for a user authentication request""" return (String(self._session_id) + self._get_userauth_request_packet(method, args)) @asyncio.coroutine def send_userauth_request(self, method, *args, key=None): """Send a user authentication request""" packet = self._get_userauth_request_packet(method, args) if key: sig = key.sign(String(self._session_id) + packet) if asyncio.iscoroutine(sig): sig = yield from sig packet += String(sig) self.send_packet(packet) def send_userauth_failure(self, partial_success): """Send a user authentication failure response""" self._auth = None self.send_packet(Byte(MSG_USERAUTH_FAILURE), NameList(get_server_auth_methods(self)), Boolean(partial_success)) def send_userauth_success(self): """Send a user authentication success response""" self.send_packet(Byte(MSG_USERAUTH_SUCCESS)) self._auth = None self._auth_in_progress = False self._auth_complete = True self._extra.update(username=self._username) self._send_deferred_packets() # This method is only in SSHServerConnection # pylint: disable=no-member self._cancel_login_timer() def send_channel_open_confirmation(self, send_chan, recv_chan, recv_window, recv_pktsize, *result_args): """Send a channel open confirmation""" self.send_packet(Byte(MSG_CHANNEL_OPEN_CONFIRMATION), UInt32(send_chan), UInt32(recv_chan), UInt32(recv_window), UInt32(recv_pktsize), *result_args) def send_channel_open_failure(self, send_chan, code, reason, lang): """Send a channel open failure""" self.send_packet(Byte(MSG_CHANNEL_OPEN_FAILURE), UInt32(send_chan), UInt32(code), String(reason), String(lang)) @asyncio.coroutine def _make_global_request(self, request, *args): """Send a global request and wait for the response""" if not self._transport: return MSG_REQUEST_FAILURE, SSHPacket(b'') waiter = asyncio.Future(loop=self._loop) 
self._global_request_waiters.append(waiter) self.send_packet(Byte(MSG_GLOBAL_REQUEST), String(request), Boolean(True), *args) return (yield from waiter) def _report_global_response(self, result): """Report back the response to a previously issued global request""" _, _, want_reply = self._global_request_queue.pop(0) if want_reply: # pragma: no branch if result: response = b'' if result is True else result self.send_packet(Byte(MSG_REQUEST_SUCCESS), response) else: self.send_packet(Byte(MSG_REQUEST_FAILURE)) if self._global_request_queue: self._service_next_global_request() def _service_next_global_request(self): """Process next item on global request queue""" handler, packet, _ = self._global_request_queue[0] if callable(handler): handler(packet) else: self._report_global_response(False) def _connection_made(self): """Handle the opening of a new connection""" raise NotImplementedError def _process_disconnect(self, pkttype, packet): """Process a disconnect message""" # pylint: disable=unused-argument code = packet.get_uint32() reason = packet.get_string() lang = packet.get_string() packet.check_end() try: reason = reason.decode('utf-8') lang = lang.decode('ascii') except UnicodeDecodeError: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Invalid disconnect message') from None if code != DISC_BY_APPLICATION: exc = DisconnectError(code, reason, lang) else: exc = None self._force_close(exc) def _process_ignore(self, pkttype, packet): """Process an ignore message""" # pylint: disable=no-self-use,unused-argument _ = packet.get_string() # data packet.check_end() # Do nothing def _process_unimplemented(self, pkttype, packet): """Process an unimplemented message response""" # pylint: disable=no-self-use,unused-argument _ = packet.get_uint32() # seq packet.check_end() # Ignore this def _process_debug(self, pkttype, packet): """Process a debug message""" # pylint: disable=unused-argument always_display = packet.get_boolean() msg = packet.get_string() lang = packet.get_string() packet.check_end() try: msg = msg.decode('utf-8') lang = lang.decode('ascii') except UnicodeDecodeError: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Invalid debug message') from None self._owner.debug_msg_received(msg, lang, always_display) def _process_service_request(self, pkttype, packet): """Process a service request""" # pylint: disable=unused-argument service = packet.get_string() packet.check_end() if service == self._next_service: self._next_service = None self.send_packet(Byte(MSG_SERVICE_ACCEPT), String(service)) if (self.is_server() and # pragma: no branch service == _USERAUTH_SERVICE): self._auth_in_progress = True self._send_deferred_packets() else: raise DisconnectError(DISC_SERVICE_NOT_AVAILABLE, 'Unexpected service request received') def _process_service_accept(self, pkttype, packet): """Process a service accept response""" # pylint: disable=unused-argument service = packet.get_string() packet.check_end() if service == self._next_service: self._next_service = None if (self.is_client() and # pragma: no branch service == _USERAUTH_SERVICE): self._auth_in_progress = True # This method is only in SSHClientConnection # pylint: disable=no-member self.try_next_auth() else: raise DisconnectError(DISC_SERVICE_NOT_AVAILABLE, 'Unexpected service accept received') def _process_ext_info(self, pkttype, packet): """Process extension information""" # pylint: disable=unused-argument extensions = {} num_extensions = packet.get_uint32() for _ in range(num_extensions): name = packet.get_string() value = packet.get_string() 
extensions[name] = value packet.check_end() if self.is_client(): self._server_sig_algs = \ extensions.get(b'server-sig-algs').split(b',') def _process_kexinit(self, pkttype, packet): """Process a key exchange request""" # pylint: disable=unused-argument if self._kex: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Key exchange already in progress') _ = packet.get_bytes(16) # cookie peer_kex_algs = packet.get_namelist() peer_host_key_algs = packet.get_namelist() enc_algs_cs = packet.get_namelist() enc_algs_sc = packet.get_namelist() mac_algs_cs = packet.get_namelist() mac_algs_sc = packet.get_namelist() cmp_algs_cs = packet.get_namelist() cmp_algs_sc = packet.get_namelist() _ = packet.get_namelist() # lang_cs _ = packet.get_namelist() # lang_sc first_kex_follows = packet.get_boolean() _ = packet.get_uint32() # reserved packet.check_end() if self.is_server(): self._client_kexinit = packet.get_consumed_payload() if b'ext-info-c' in peer_kex_algs: self._can_send_ext_info = True else: self._server_kexinit = packet.get_consumed_payload() if self._kexinit_sent: self._kexinit_sent = False else: self._send_kexinit() if self._gss: self._gss.reset() gss_mechs = self._gss.mechs if self._gss else [] kex_algs = expand_kex_algs(self._kex_algs, gss_mechs, bool(self._server_host_key_algs)) kex_alg = self._choose_alg('key exchange', kex_algs, peer_kex_algs) self._kex = get_kex(self, kex_alg) self._ignore_first_kex = (first_kex_follows and self._kex.algorithm != peer_kex_algs[0]) if self.is_server(): # This method is only in SSHServerConnection # pylint: disable=no-member if (not self._choose_server_host_key(peer_host_key_algs) and not kex_alg.startswith(b'gss-')): raise DisconnectError(DISC_KEY_EXCHANGE_FAILED, 'Unable ' 'to find compatible server host key') self._enc_alg_cs = self._choose_alg('encryption', self._enc_algs, enc_algs_cs) self._enc_alg_sc = self._choose_alg('encryption', self._enc_algs, enc_algs_sc) self._mac_alg_cs = self._choose_alg('MAC', self._mac_algs, mac_algs_cs) self._mac_alg_sc = self._choose_alg('MAC', self._mac_algs, mac_algs_sc) self._cmp_alg_cs = self._choose_alg('compression', self._cmp_algs, cmp_algs_cs) self._cmp_alg_sc = self._choose_alg('compression', self._cmp_algs, cmp_algs_sc) if self.is_client(): self._kex.start() def _process_newkeys(self, pkttype, packet): """Process a new keys message, finishing a key exchange""" # pylint: disable=unused-argument packet.check_end() if self._next_recv_cipher: self._recv_cipher = self._next_recv_cipher self._recv_blocksize = self._next_recv_blocksize self._recv_mac = self._next_recv_mac self._recv_mode = self._next_recv_mode self._recv_macsize = self._next_recv_macsize self._decompressor = self._next_decompressor self._decompress_after_auth = self._next_decompress_after_auth self._next_recv_cipher = None else: raise DisconnectError(DISC_PROTOCOL_ERROR, 'New keys not negotiated') if self.is_client() and not (self._auth_in_progress or self._auth_complete): self.send_service_request(_USERAUTH_SERVICE) def _process_userauth_request(self, pkttype, packet): """Process a user authentication request""" # pylint: disable=unused-argument username = packet.get_string() service = packet.get_string() method = packet.get_string() if service != _CONNECTION_SERVICE: raise DisconnectError(DISC_SERVICE_NOT_AVAILABLE, 'Unexpected service in auth request') try: username = saslprep(username.decode('utf-8')) except (UnicodeDecodeError, SASLPrepError): raise DisconnectError(DISC_PROTOCOL_ERROR, 'Invalid auth request message') from None if self.is_client(): raise 
DisconnectError(DISC_PROTOCOL_ERROR, 'Unexpected userauth request') elif self._auth_complete: # Silent ignore requests if we're already authenticated pass else: if username != self._username: self._username = username if not self._owner.begin_auth(username): self.send_userauth_success() return if self._auth: self._auth.cancel() self._auth = lookup_server_auth(self, self._username, method, packet) def _process_userauth_failure(self, pkttype, packet): """Process a user authentication failure response""" # pylint: disable=unused-argument self._auth_methods = packet.get_namelist() partial_success = packet.get_boolean() packet.check_end() if self.is_client() and self._auth: if partial_success: # pragma: no cover # Partial success not implemented yet self._auth.auth_succeeded() else: self._auth.auth_failed() # This method is only in SSHClientConnection # pylint: disable=no-member self.try_next_auth() else: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Unexpected userauth response') def _process_userauth_success(self, pkttype, packet): """Process a user authentication success response""" # pylint: disable=unused-argument packet.check_end() if self.is_client() and self._auth: self._auth.auth_succeeded() self._auth.cancel() self._auth = None self._auth_in_progress = False self._auth_complete = True if self._agent: self._agent.close() self._agent = None self._extra.update(username=self._username) self._send_deferred_packets() self._owner.auth_completed() if not self._auth_waiter.cancelled(): # pragma: no branch self._auth_waiter.set_result(None) self._auth_waiter = None else: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Unexpected userauth response') def _process_userauth_banner(self, pkttype, packet): """Process a user authentication banner message""" # pylint: disable=unused-argument msg = packet.get_string() lang = packet.get_string() packet.check_end() try: msg = msg.decode('utf-8') lang = lang.decode('ascii') except UnicodeDecodeError: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Invalid userauth banner') from None if self.is_client(): self._owner.auth_banner_received(msg, lang) else: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Unexpected userauth banner') def _process_global_request(self, pkttype, packet): """Process a global request""" # pylint: disable=unused-argument request = packet.get_string() want_reply = packet.get_boolean() try: request = request.decode('ascii') except UnicodeDecodeError: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Invalid global request') from None name = '_process_' + map_handler_name(request) + '_global_request' handler = getattr(self, name, None) self._global_request_queue.append((handler, packet, want_reply)) if len(self._global_request_queue) == 1: self._service_next_global_request() def _process_global_response(self, pkttype, packet): """Process a global response""" # pylint: disable=unused-argument if self._global_request_waiters: waiter = self._global_request_waiters.pop(0) if not waiter.cancelled(): # pragma: no branch waiter.set_result((pkttype, packet)) else: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Unexpected global response') def _process_channel_open(self, pkttype, packet): """Process a channel open request""" # pylint: disable=unused-argument chantype = packet.get_string() send_chan = packet.get_uint32() send_window = packet.get_uint32() send_pktsize = packet.get_uint32() try: chantype = chantype.decode('ascii') except UnicodeDecodeError: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Invalid channel open request') from None try: name = '_process_' + 
map_handler_name(chantype) + '_open' handler = getattr(self, name, None) if callable(handler): chan, session = handler(packet) chan.process_open(send_chan, send_window, send_pktsize, session) else: raise ChannelOpenError(OPEN_UNKNOWN_CHANNEL_TYPE, 'Unknown channel type') except ChannelOpenError as exc: self.send_channel_open_failure(send_chan, exc.code, exc.reason, exc.lang) def _process_channel_open_confirmation(self, pkttype, packet): """Process a channel open confirmation response""" # pylint: disable=unused-argument recv_chan = packet.get_uint32() send_chan = packet.get_uint32() send_window = packet.get_uint32() send_pktsize = packet.get_uint32() chan = self._channels.get(recv_chan) if chan: chan.process_open_confirmation(send_chan, send_window, send_pktsize, packet) else: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Invalid channel number') def _process_channel_open_failure(self, pkttype, packet): """Process a channel open failure response""" # pylint: disable=unused-argument recv_chan = packet.get_uint32() code = packet.get_uint32() reason = packet.get_string() lang = packet.get_string() packet.check_end() try: reason = reason.decode('utf-8') lang = lang.decode('ascii') except UnicodeDecodeError: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Invalid channel open failure') from None chan = self._channels.get(recv_chan) if chan: chan.process_open_failure(code, reason, lang) else: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Invalid channel number') def _process_channel_msg(self, pkttype, packet): """Process a channel-specific message""" recv_chan = packet.get_uint32() chan = self._channels.get(recv_chan) if chan: chan.process_packet(pkttype, packet) else: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Invalid channel number') packet_handlers = { MSG_DISCONNECT: _process_disconnect, MSG_IGNORE: _process_ignore, MSG_UNIMPLEMENTED: _process_unimplemented, MSG_DEBUG: _process_debug, MSG_SERVICE_REQUEST: _process_service_request, MSG_SERVICE_ACCEPT: _process_service_accept, MSG_EXT_INFO: _process_ext_info, MSG_KEXINIT: _process_kexinit, MSG_NEWKEYS: _process_newkeys, MSG_USERAUTH_REQUEST: _process_userauth_request, MSG_USERAUTH_FAILURE: _process_userauth_failure, MSG_USERAUTH_SUCCESS: _process_userauth_success, MSG_USERAUTH_BANNER: _process_userauth_banner, MSG_GLOBAL_REQUEST: _process_global_request, MSG_REQUEST_SUCCESS: _process_global_response, MSG_REQUEST_FAILURE: _process_global_response, MSG_CHANNEL_OPEN: _process_channel_open, MSG_CHANNEL_OPEN_CONFIRMATION: _process_channel_open_confirmation, MSG_CHANNEL_OPEN_FAILURE: _process_channel_open_failure, MSG_CHANNEL_WINDOW_ADJUST: _process_channel_msg, MSG_CHANNEL_DATA: _process_channel_msg, MSG_CHANNEL_EXTENDED_DATA: _process_channel_msg, MSG_CHANNEL_EOF: _process_channel_msg, MSG_CHANNEL_CLOSE: _process_channel_msg, MSG_CHANNEL_REQUEST: _process_channel_msg, MSG_CHANNEL_SUCCESS: _process_channel_msg, MSG_CHANNEL_FAILURE: _process_channel_msg } def abort(self): """Forcibly close the SSH connection This method closes the SSH connection immediately, without waiting for pending operations to complete and wihtout sending an explicit SSH disconnect message. Buffered data waiting to be sent will be lost and no more data will be received. When the the connection is closed, :meth:`connection_lost() ` on the associated :class:`SSHClient` object will be called with the value ``None``. 
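A rough usage sketch (the host name and helper function name are only
placeholders, and error handling is omitted)::

    import asyncio, asyncssh

    @asyncio.coroutine
    def run_and_drop():
        conn = yield from asyncssh.connect('example.com')

        try:
            result = yield from conn.run('uptime')
            print(result.stdout, end='')
        finally:
            # Drop the connection immediately, without the usual
            # SSH disconnect exchange
            conn.abort()

    asyncio.get_event_loop().run_until_complete(run_and_drop())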
""" self._force_close(None) def close(self): """Cleanly close the SSH connection This method calls :meth:`disconnect` with the reason set to indicate that the connection was closed explicitly by the application. """ self.disconnect(DISC_BY_APPLICATION, 'Disconnected by application') @asyncio.coroutine def wait_closed(self): """Wait for this connection to close This method is a coroutine which can be called to block until this connection has finished closing. """ yield from self._close_event.wait() def disconnect(self, code, reason, lang=DEFAULT_LANG): """Disconnect the SSH connection This method sends a disconnect message and closes the SSH connection after buffered data waiting to be written has been sent. No more data will be received. When the connection is fully closed, :meth:`connection_lost() ` on the associated :class:`SSHClient` or :class:`SSHServer` object will be called with the value ``None``. :param int code: The reason for the disconnect, from :ref:`disconnect reason codes ` :param str reason: A human readable reason for the disconnect :param str lang: The language the reason is in """ if not self._transport: return for chan in list(self._channels.values()): chan.close() self._send_disconnect(code, reason, lang) self._transport.close() self._transport = None self._loop.call_soon(self._cleanup, None) def get_extra_info(self, name, default=None): """Get additional information about the connection This method returns extra information about the connection once it is established. Supported values include everything supported by a socket transport plus: | username | client_version | server_version | send_cipher | send_mac | send_compression | recv_cipher | recv_mac | recv_compression See :meth:`get_extra_info() ` in :class:`asyncio.BaseTransport` for more information. """ return self._extra.get(name, self._transport.get_extra_info(name, default) if self._transport else default) def send_debug(self, msg, lang=DEFAULT_LANG, always_display=False): """Send a debug message on this connection This method can be called to send a debug message to the other end of the connection. :param str msg: The debug message to send :param str lang: The language the message is in :param bool always_display: Whether or not to display the message """ self.send_packet(Byte(MSG_DEBUG), Boolean(always_display), String(msg), String(lang)) def create_tcp_channel(self, encoding=None, window=_DEFAULT_WINDOW, max_pktsize=_DEFAULT_MAX_PKTSIZE): """Create an SSH TCP channel for a new direct TCP connection This method can be called by :meth:`connection_requested() ` to create an :class:`SSHTCPChannel` with the desired encoding, window, and max packet size for a newly created SSH direct connection. :param str encoding: (optional) The Unicode encoding to use for data exchanged on the connection. This defaults to ``None``, allowing the application to send and receive raw bytes. :param int window: (optional) The receive window size for this session :param int max_pktsize: (optional) The maximum packet size for this session :returns: :class:`SSHTCPChannel` """ return SSHTCPChannel(self, self._loop, encoding, window, max_pktsize) def create_unix_channel(self, encoding=None, window=_DEFAULT_WINDOW, max_pktsize=_DEFAULT_MAX_PKTSIZE): """Create an SSH UNIX channel for a new direct UNIX domain connection This method can be called by :meth:`unix_connection_requested() ` to create an :class:`SSHUNIXChannel` with the desired encoding, window, and max packet size for a newly created SSH direct UNIX domain socket connection. 
:param str encoding: (optional) The Unicode encoding to use for data exchanged on the connection. This defaults to ``None``, allowing the application to send and receive raw bytes. :param int window: (optional) The receive window size for this session :param int max_pktsize: (optional) The maximum packet size for this session :returns: :class:`SSHUNIXChannel` """ return SSHUNIXChannel(self, self._loop, encoding, window, max_pktsize) def create_x11_channel(self, encoding=None, window=_DEFAULT_WINDOW, max_pktsize=_DEFAULT_MAX_PKTSIZE): """Create an SSH X11 channel to use in X11 forwarding""" return SSHX11Channel(self, self._loop, encoding, window, max_pktsize) def create_agent_channel(self, encoding=None, window=_DEFAULT_WINDOW, max_pktsize=_DEFAULT_MAX_PKTSIZE): """Create an SSH agent channel to use in agent forwarding""" return SSHAgentChannel(self, self._loop, encoding, window, max_pktsize) @asyncio.coroutine def create_connection(self, session_factory, remote_host, remote_port, orig_host='', orig_port=0, *, encoding=None, window=_DEFAULT_WINDOW, max_pktsize=_DEFAULT_MAX_PKTSIZE): """Create an SSH direct or forwarded TCP connection""" raise NotImplementedError @asyncio.coroutine def create_unix_connection(self, session_factory, remote_path, *, encoding=None, window=_DEFAULT_WINDOW, max_pktsize=_DEFAULT_MAX_PKTSIZE): """Create an SSH direct or forwarded UNIX domain socket connection""" raise NotImplementedError @asyncio.coroutine def forward_connection(self, dest_host, dest_port): """Forward a tunneled TCP connection This method is a coroutine which can be returned by a ``session_factory`` to forward connections tunneled over SSH to the specified destination host and port. :param str dest_host: The hostname or address to forward the connections to :param int dest_port: The port number to forward the connections to :returns: :class:`SSHTCPSession` """ try: if dest_host == '': dest_host = None _, peer = yield from self._loop.create_connection(SSHForwarder, dest_host, dest_port) except OSError as exc: raise ChannelOpenError(OPEN_CONNECT_FAILED, str(exc)) from None return SSHForwarder(peer) @asyncio.coroutine def forward_unix_connection(self, dest_path): """Forward a tunneled UNIX domain socket connection This method is a coroutine which can be returned by a ``session_factory`` to forward connections tunneled over SSH to the specified destination path. :param str dest_path: The path to forward the connection to :returns: :class:`SSHUNIXSession` """ try: _, peer = \ yield from self._loop.create_unix_connection(SSHForwarder, dest_path) except OSError as exc: raise ChannelOpenError(OPEN_CONNECT_FAILED, str(exc)) from None return SSHForwarder(peer) @asyncio.coroutine def forward_local_port(self, listen_host, listen_port, dest_host, dest_port): """Set up local port forwarding This method is a coroutine which attempts to set up port forwarding from a local listening port to a remote host and port via the SSH connection. If the request is successful, the return value is an :class:`SSHListener` object which can be used later to shut down the port forwarding. 
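A minimal sketch of setting this up, with placeholder host names and
port numbers::

    import asyncio, asyncssh

    @asyncio.coroutine
    def forward():
        conn = yield from asyncssh.connect('gateway.example.com')

        # Forward local port 8080 through the tunnel to port 80
        # on www.example.com
        listener = yield from conn.forward_local_port('', 8080,
                                                      'www.example.com', 80)
        yield from listener.wait_closed()

    asyncio.get_event_loop().run_until_complete(forward())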
:param str listen_host: The hostname or address on the local host to listen on :param int listen_port: The port number on the local host to listen on :param str dest_host: The hostname or address to forward the connections to :param int dest_port: The port number to forward the connections to :returns: :class:`SSHListener` :raises: :exc:`OSError` if the listener can't be opened """ @asyncio.coroutine def tunnel_connection(session_factory, orig_host, orig_port): """Forward a local connection over SSH""" return (yield from self.create_connection(session_factory, dest_host, dest_port, orig_host, orig_port)) return (yield from create_tcp_forward_listener(self, self._loop, tunnel_connection, listen_host, listen_port)) @asyncio.coroutine def forward_local_path(self, listen_path, dest_path): """Set up local UNIX domain socket forwarding This method is a coroutine which attempts to set up UNIX domain socket forwarding from a local listening path to a remote path via the SSH connection. If the request is successful, the return value is an :class:`SSHListener` object which can be used later to shut down the UNIX domain socket forwarding. :param str listen_path: The path on the local host to listen on :param str dest_path: The path on the remote host to forward the connections to :returns: :class:`SSHListener` :raises: :exc:`OSError` if the listener can't be opened """ @asyncio.coroutine def tunnel_connection(session_factory): """Forward a local connection over SSH""" return (yield from self.create_unix_connection(session_factory, dest_path)) return (yield from create_unix_forward_listener(self, self._loop, tunnel_connection, listen_path)) class SSHClientConnection(SSHConnection): """SSH client connection This class represents an SSH client connection. Once authentication is successful on a connection, new client sessions can be opened by calling :meth:`create_session`. Direct TCP connections can be opened by calling :meth:`create_connection`. Remote listeners for forwarded TCP connections can be opened by calling :meth:`create_server`. Direct UNIX domain socket connections can be opened by calling :meth:`create_unix_connection`. Remote listeners for forwarded UNIX domain socket connections can be opened by calling :meth:`create_unix_server`. TCP port forwarding can be set up by calling :meth:`forward_local_port` or :meth:`forward_remote_port`. UNIX domain socket forwarding can be set up by calling :meth:`forward_local_path` or :meth:`forward_remote_path`. 
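A minimal end-to-end sketch of using such a connection (the host, command,
and helper name are placeholders, and key and host validation options are
left at their defaults)::

    import asyncio, asyncssh

    @asyncio.coroutine
    def run_client():
        conn = yield from asyncssh.connect('example.com')

        try:
            result = yield from conn.run('echo "Hello, world!"', check=True)
            print(result.stdout, end='')
        finally:
            conn.close()
            yield from conn.wait_closed()

    asyncio.get_event_loop().run_until_complete(run_client())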
""" def __init__(self, client_factory, loop, client_version, kex_algs, encryption_algs, mac_algs, compression_algs, signature_algs, rekey_bytes, rekey_seconds, host, port, known_hosts, x509_trusted_certs, x509_trusted_cert_paths, x509_purposes, username, password, client_keys, gss_host, gss_delegate_creds, agent, agent_path, auth_waiter): super().__init__(client_factory, loop, client_version, kex_algs, encryption_algs, mac_algs, compression_algs, signature_algs, rekey_bytes, rekey_seconds, server=False) self._host = host self._port = port if port != _DEFAULT_PORT else None self._known_hosts = known_hosts self._x509_trusted_certs = x509_trusted_certs self._x509_trusted_cert_paths = x509_trusted_cert_paths self._x509_purposes = x509_purposes self._username = saslprep(username) self._password = password self._client_keys = client_keys self._agent = agent self._agent_path = agent_path self._auth_waiter = auth_waiter if gss_host: try: self._gss = GSSClient(gss_host, gss_delegate_creds) self._gss_mic_auth = True except GSSError: pass self._server_host_keys = set() self._server_ca_keys = set() self._revoked_server_keys = set() self._x509_revoked_certs = [] self._x509_trusted_subjects = [] self._x509_revoked_subjects = [] self._kbdint_password_auth = False self._remote_listeners = {} self._dynamic_remote_listeners = {} def _connection_made(self): """Handle the opening of a new connection""" if self._known_hosts is None: self._server_host_keys = None self._server_ca_keys = None else: if not self._known_hosts: self._known_hosts = os.path.join(os.path.expanduser('~'), '.ssh', 'known_hosts') known_hosts = match_known_hosts(self._known_hosts, self._host, self._peer_addr, self._port) server_host_keys, server_ca_keys, revoked_server_keys, \ server_x509_certs, revoked_x509_certs, \ server_x509_subjects, revoked_x509_subjects = known_hosts self._server_host_keys = set() self._server_host_key_algs = [] for key in server_host_keys: self._server_host_keys.add(key) if key.algorithm not in self._server_host_key_algs: self._server_host_key_algs.extend(key.sig_algorithms) if server_ca_keys: self._server_host_key_algs = \ get_certificate_algs() + self._server_host_key_algs self._server_ca_keys = set(server_ca_keys) self._revoked_server_keys = set(revoked_server_keys) if self._x509_trusted_certs is not None: self._x509_trusted_certs = list(self._x509_trusted_certs) self._x509_trusted_certs.extend(server_x509_certs) if self._x509_trusted_certs or self._x509_trusted_cert_paths: self._server_host_key_algs = \ get_x509_certificate_algs() + self._server_host_key_algs self._x509_revoked_certs = set(revoked_x509_certs) self._x509_trusted_subjects = server_x509_subjects self._x509_revoked_subjects = revoked_x509_subjects if not self._server_host_key_algs: if self._known_hosts is None: self._server_host_key_algs = (get_certificate_algs() + get_public_key_algs()) if self._x509_trusted_certs is not None: self._server_host_key_algs = \ get_x509_certificate_algs() + self._server_host_key_algs elif self._gss: self._server_host_key_algs = [b'null'] else: raise DisconnectError(DISC_HOST_KEY_NOT_VERIFYABLE, 'No trusted server host keys available') def _cleanup(self, exc): """Clean up this client connection""" if self._agent: self._agent.close() self._agent = None if self._remote_listeners: for listener in self._remote_listeners.values(): listener.close() self._remote_listeners = {} self._dynamic_remote_listeners = {} if self._auth_waiter: if not self._auth_waiter.cancelled(): # pragma: no branch self._auth_waiter.set_exception(exc) 
self._auth_waiter = None super()._cleanup(exc) def _validate_server_openssh_certificate(self, cert): """Validate the server's OpenSSH certificate""" if cert.signing_key in self._revoked_server_keys: raise DisconnectError(DISC_HOST_KEY_NOT_VERIFYABLE, 'Revoked server CA key') if self._server_ca_keys is not None and \ cert.signing_key not in self._server_ca_keys: raise DisconnectError(DISC_HOST_KEY_NOT_VERIFYABLE, 'Untrusted server CA key') try: cert.validate(CERT_TYPE_HOST, self._host) except ValueError as exc: raise DisconnectError(DISC_HOST_KEY_NOT_VERIFYABLE, str(exc)) from None return cert.key def _validate_server_x509_certificate_chain(self, cert): """Validate the server's X.509 certificate""" if (self._x509_revoked_subjects and any(pattern.matches(cert.subject) for pattern in self._x509_revoked_subjects)): raise DisconnectError(DISC_HOST_KEY_NOT_VERIFYABLE, 'Revoked server X.509 subject name') if (self._x509_trusted_subjects and not any(pattern.matches(cert.subject) for pattern in self._x509_trusted_subjects)): raise DisconnectError(DISC_HOST_KEY_NOT_VERIFYABLE, 'Untrusted server X.509 subject name') try: # Only validate hostname against X.509 certificate host # principals when there are no X.509 trusted subject # entries matched in known_hosts. if self._x509_trusted_subjects: host_principal = None else: host_principal = self._host cert.validate_chain(self._x509_trusted_certs, self._x509_trusted_cert_paths, self._x509_revoked_certs, self._x509_purposes, host_principal=host_principal) except ValueError as exc: raise DisconnectError(DISC_HOST_KEY_NOT_VERIFYABLE, str(exc)) from None return cert.key def validate_server_host_key(self, data): """Validate and return the server's host key""" try: cert = decode_ssh_certificate(data) except KeyImportError: pass else: if cert.is_x509_chain: return self._validate_server_x509_certificate_chain(cert) else: return self._validate_server_openssh_certificate(cert) try: key = decode_ssh_public_key(data) except KeyImportError: pass else: if key in self._revoked_server_keys: raise DisconnectError(DISC_HOST_KEY_NOT_VERIFYABLE, 'Revoked server host key') if self._server_host_keys is not None and \ key not in self._server_host_keys: raise DisconnectError(DISC_HOST_KEY_NOT_VERIFYABLE, 'Untrusted server host key') return key raise DisconnectError(DISC_HOST_KEY_NOT_VERIFYABLE, 'Unable to decode server host key') def try_next_auth(self): """Attempt client authentication using the next compatible method""" if self._auth: self._auth.cancel() self._auth = None while self._auth_methods: method = self._auth_methods.pop(0) self._auth = lookup_client_auth(self, method) if self._auth: return self._force_close(DisconnectError(DISC_NO_MORE_AUTH_METHODS_AVAILABLE, 'Permission denied')) def gss_kex_auth_requested(self): """Return whether to allow GSS key exchange authentication or not""" if self._gss_kex_auth: self._gss_kex_auth = False return True else: return False def gss_mic_auth_requested(self): """Return whether to allow GSS MIC authentication or not""" if self._gss_mic_auth: self._gss_mic_auth = False return True else: return False @asyncio.coroutine def public_key_auth_requested(self): """Return a client key pair to authenticate with""" while True: if not self._client_keys: result = self._owner.public_key_auth_requested() if asyncio.iscoroutine(result): result = yield from result if not result: return None self._client_keys = load_keypairs(result) keypair = self._client_keys.pop(0) if self._server_sig_algs: for alg in keypair.sig_algorithms: if alg in self._sig_algs 
and alg in self._server_sig_algs: keypair.set_sig_algorithm(alg) return keypair if keypair.sig_algorithms[-1] in self._sig_algs: return keypair @asyncio.coroutine def password_auth_requested(self): """Return a password to authenticate with""" # Only allow password auth if the connection supports encryption # and a MAC -- Disable this for now: we don't allow null ciphers/macs # # if (not self._send_cipher or # (not self._send_mac and # self._send_mode not in ('chacha', 'gcm'))): # return None if self._password is not None: result = self._password self._password = None else: result = self._owner.password_auth_requested() if asyncio.iscoroutine(result): result = yield from result return result @asyncio.coroutine def password_change_requested(self, prompt, lang): """Return a password to authenticate with and what to change it to""" result = self._owner.password_change_requested(prompt, lang) if asyncio.iscoroutine(result): result = yield from result return result def password_changed(self): """Report a successful password change""" self._owner.password_changed() def password_change_failed(self): """Report a failed password change""" self._owner.password_change_failed() @asyncio.coroutine def kbdint_auth_requested(self): """Return the list of supported keyboard-interactive auth methods If keyboard-interactive auth is not supported in the client but a password was provided when the connection was opened, this will allow sending the password via keyboard-interactive auth. """ # Only allow password auth if the connection supports encryption # and a MAC -- Disable this for now: we don't allow null ciphers/macs # # if (not self._send_cipher or # (not self._send_mac and # self._send_mode not in ('chacha', 'gcm'))): # return None result = self._owner.kbdint_auth_requested() if asyncio.iscoroutine(result): result = yield from result if result is NotImplemented: if self._password is not None and not self._kbdint_password_auth: self._kbdint_password_auth = True result = '' else: result = None return result @asyncio.coroutine def kbdint_challenge_received(self, name, instructions, lang, prompts): """Return responses to a keyboard-interactive auth challenge""" if self._kbdint_password_auth: if len(prompts) == 0: # Silently drop any empty challenges used to print messages result = [] elif len(prompts) == 1 and 'password' in prompts[0][0].lower(): password = yield from self.password_auth_requested() result = [password] if password is not None else None else: result = None else: result = self._owner.kbdint_challenge_received(name, instructions, lang, prompts) if asyncio.iscoroutine(result): result = yield from result return result def _process_session_open(self, packet): """Process an inbound session open request These requests are disallowed on an SSH client. """ # pylint: disable=no-self-use,unused-argument raise ChannelOpenError(OPEN_ADMINISTRATIVELY_PROHIBITED, 'Session open forbidden on client') def _process_direct_tcpip_open(self, packet): """Process an inbound direct TCP/IP channel open request These requests are disallowed on an SSH client. 
""" # pylint: disable=no-self-use,unused-argument raise ChannelOpenError(OPEN_ADMINISTRATIVELY_PROHIBITED, 'Direct TCP/IP open forbidden on client') def _process_forwarded_tcpip_open(self, packet): """Process an inbound forwarded TCP/IP channel open request""" dest_host = packet.get_string() dest_port = packet.get_uint32() orig_host = packet.get_string() orig_port = packet.get_uint32() packet.check_end() try: dest_host = dest_host.decode('utf-8') orig_host = orig_host.decode('utf-8') except UnicodeDecodeError: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Invalid forwarded ' 'TCP/IP channel open request') from None # Some buggy servers send back a port of ``0`` instead of the actual # listening port when reporting connections which arrive on a listener # set up on a dynamic port. This lookup attempts to work around that. listener = (self._remote_listeners.get((dest_host, dest_port)) or self._dynamic_remote_listeners.get(dest_host)) if listener: return listener.process_connection(orig_host, orig_port) else: raise ChannelOpenError(OPEN_CONNECT_FAILED, 'No such listener') @asyncio.coroutine def close_client_tcp_listener(self, listen_host, listen_port): """Close a remote TCP/IP listener""" yield from self._make_global_request( b'cancel-tcpip-forward', String(listen_host), UInt32(listen_port)) listener = self._remote_listeners.get((listen_host, listen_port)) if listener: if self._dynamic_remote_listeners.get(listen_host) == listener: del self._dynamic_remote_listeners[listen_host] del self._remote_listeners[listen_host, listen_port] def _process_direct_streamlocal_at_openssh_dot_com_open(self, packet): """Process an inbound direct UNIX domain channel open request These requests are disallowed on an SSH client. """ # pylint: disable=no-self-use,unused-argument raise ChannelOpenError(OPEN_ADMINISTRATIVELY_PROHIBITED, 'Direct UNIX domain socket open ' 'forbidden on client') def _process_forwarded_streamlocal_at_openssh_dot_com_open(self, packet): """Process an inbound forwarded UNIX domain channel open request""" dest_path = packet.get_string() _ = packet.get_string() # reserved packet.check_end() try: dest_path = dest_path.decode('utf-8') except UnicodeDecodeError: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Invalid forwarded ' 'UNIX domain channel open request') from None listener = self._remote_listeners.get(dest_path) if listener: return listener.process_connection() else: raise ChannelOpenError(OPEN_CONNECT_FAILED, 'No such listener') @asyncio.coroutine def close_client_unix_listener(self, listen_path): """Close a remote UNIX domain socket listener""" yield from self._make_global_request( b'cancel-streamlocal-forward@openssh.com', String(listen_path)) if listen_path in self._remote_listeners: del self._remote_listeners[listen_path] def _process_x11_open(self, packet): """Process an inbound X11 channel open request""" orig_host = packet.get_string() orig_port = packet.get_uint32() packet.check_end() if self._x11_listener: chan = self.create_x11_channel() chan.set_inbound_peer_names(orig_host, orig_port) return chan, self._x11_listener.forward_connection() else: raise ChannelOpenError(OPEN_CONNECT_FAILED, 'X11 forwarding disabled') def _process_auth_agent_at_openssh_dot_com_open(self, packet): """Process an inbound auth agent channel open request""" packet.check_end() if self._agent_path: return (self.create_unix_channel(), self.forward_unix_connection(self._agent_path)) else: raise ChannelOpenError(OPEN_CONNECT_FAILED, 'Auth agent forwarding disabled') @asyncio.coroutine def 
attach_x11_listener(self, chan, display, auth_path, single_connection): """Attach a channel to a local X11 display""" if not display: display = os.environ.get('DISPLAY') if not display: raise ValueError('X11 display not set') if not self._x11_listener: self._x11_listener = yield from create_x11_client_listener( self._loop, display, auth_path) return self._x11_listener.attach(display, chan, single_connection) def detach_x11_listener(self, chan): """Detach a session from a local X11 listener""" if self._x11_listener: if self._x11_listener.detach(chan): self._x11_listener = None @asyncio.coroutine def create_session(self, session_factory, command=None, *, subsystem=None, env={}, term_type=None, term_size=None, term_modes={}, x11_forwarding=False, x11_display=None, x11_auth_path=None, x11_single_connection=False, encoding='utf-8', window=_DEFAULT_WINDOW, max_pktsize=_DEFAULT_MAX_PKTSIZE): """Create an SSH client session This method is a coroutine which can be called to create an SSH client session used to execute a command, start a subsystem such as sftp, or if no command or subsystem is specified run an interactive shell. Optional arguments allow terminal and environment information to be provided. By default, this class expects string data in its send and receive functions, which it encodes on the SSH connection in UTF-8 (ISO 10646) format. An optional encoding argument can be passed in to select a different encoding, or ``None`` can be passed in if the application wishes to send and receive raw bytes. Other optional arguments include the SSH receive window size and max packet size which default to 2 MB and 32 KB, respectively. :param callable session_factory: A callable which returns an :class:`SSHClientSession` object that will be created to handle activity on this session :param str command: (optional) The remote command to execute. By default, an interactive shell is started if no command or subsystem is provided. :param str subsystem: (optional) The name of a remote subsystem to start up :param dictionary env: (optional) The set of environment variables to set for this session. Keys and values passed in here will be converted to Unicode strings encoded as UTF-8 (ISO 10646) for transmission. .. note:: Many SSH servers restrict which environment variables a client is allowed to set. The server's configuration may need to be edited before environment variables can be successfully set in the remote environment. :param str term_type: (optional) The terminal type to set for this session. If this is not set, a pseudo-terminal will not be requested for this session. :param term_size: (optional) The terminal width and height in characters and optionally the width and height in pixels :param term_modes: (optional) POSIX terminal modes to set for this session, where keys are taken from :ref:`POSIX terminal modes ` with values defined in section 8 of :rfc:`RFC 4254 <4254#section-8>`. 
:param bool x11_forwarding: (optional) Whether or not to request X11 forwarding for this session, defaulting to ``False`` :param str x11_display: (optional) The display that X11 connections should be forwarded to, defaulting to the value in the environment variable ``DISPLAY`` :param str x11_auth_path: (optional) The path to the Xauthority file to read X11 authentication data from, defaulting to the value in the environment variable ``XAUTHORITY`` or the file ``.Xauthority`` in the user's home directory if that's not set :param bool x11_single_connection: (optional) Whether or not to limit X11 forwarding to a single connection, defaulting to ``False`` :param str encoding: (optional) The Unicode encoding to use for data exchanged on the connection :param int window: (optional) The receive window size for this session :param int max_pktsize: (optional) The maximum packet size for this session :type term_size: *tuple of 2 or 4 integers* :returns: an :class:`SSHClientChannel` and :class:`SSHClientSession` :raises: :exc:`ChannelOpenError` if the session can't be opened """ chan = SSHClientChannel(self, self._loop, encoding, window, max_pktsize) return (yield from chan.create(session_factory, command, subsystem, env, term_type, term_size, term_modes, x11_forwarding, x11_display, x11_auth_path, x11_single_connection, bool(self._agent_path))) @asyncio.coroutine def open_session(self, *args, **kwargs): """Open an SSH client session This method is a coroutine wrapper around :meth:`create_session` designed to provide a "high-level" stream interface for creating an SSH client session. Instead of taking a ``session_factory`` argument for constructing an object which will handle activity on the session via callbacks, it returns an :class:`SSHWriter` and two :class:`SSHReader` objects representing stdin, stdout, and stderr which can be used to perform I/O on the session. With the exception of ``session_factory``, all of the arguments to :meth:`create_session` are supported and have the same meaning. """ chan, session = yield from self.create_session(SSHClientStreamSession, *args, **kwargs) return (SSHWriter(session, chan), SSHReader(session, chan), SSHReader(session, chan, EXTENDED_DATA_STDERR)) # pylint: disable=redefined-builtin @async_context_manager def create_process(self, *args, bufsize=io.DEFAULT_BUFFER_SIZE, input=None, stdin=PIPE, stdout=PIPE, stderr=PIPE, **kwargs): """Create a process on the remote system This method is a coroutine wrapper around :meth:`create_session` which can be used to execute a command, start a subsystem, or start an interactive shell, optionally redirecting stdin, stdout, and stderr to and from files or pipes attached to other local and remote processes. By default, the stdin, stdout, and stderr arguments default to the special value ``PIPE`` which means that they can be read and written interactively via stream objects which are members of the :class:`SSHClientProcess` object this method returns. If other file-like objects are provided as arguments, input or output will automatically be redirected to them. The special value ``DEVNULL`` can be used to provide no input or discard all output, and the special value ``STDOUT`` can be provided as ``stderr`` to send its output to the same stream as ``stdout``. In addition to the arguments below, all arguments to :meth:`create_session` except for ``session_factory`` are supported and have the same meaning. 
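As an illustrative sketch, the default ``PIPE`` streams can be used to
interact with a remote command such as ``bc``, assuming an already
established connection ``conn``::

    import asyncio

    @asyncio.coroutine
    def run_calc(conn):
        process = yield from conn.create_process('bc')

        for expr in ('2 + 2', '3 * 7'):
            # Write a request to stdin and read one line of output back
            process.stdin.write(expr + '\n')
            result = yield from process.stdout.readline()
            print(expr, '=', result, end='')

        process.stdin.write_eof()
        yield from process.wait()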
:param int bufsize: (optional) Buffer size to use when feeding data from a file to stdin :param input: (optional) Input data to feed to standard input of the remote process. If specified, this argument takes precedence over stdin. Data should be a str if encoding is set, or bytes if not. :param stdin: (optional) A filename, file-like object, file descriptor, socket, or :class:`SSHReader` to feed to standard input of the remote process, or ``DEVNULL`` to provide no input. :param stdout: (optional) A filename, file-like object, file descriptor, socket, or :class:`SSHWriter` to feed standard output of the remote process to, or ``DEVNULL`` to discard this output. :param stderr: (optional) A filename, file-like object, file descriptor, socket, or :class:`SSHWriter` to feed standard error of the remote process to, ``DEVNULL`` to discard this output, or ``STDOUT`` to feed standard error to the same place as stdout. :type input: str or bytes :returns: :class:`SSHClientProcess` :raises: :exc:`ChannelOpenError` if the channel can't be opened """ chan, process = yield from self.create_session(SSHClientProcess, *args, **kwargs) if input: chan.write(input) chan.write_eof() stdin = None yield from process.redirect(stdin, stdout, stderr, bufsize) return process # pylint: enable=redefined-builtin @asyncio.coroutine def run(self, *args, check=False, **kwargs): """Run a command on the remote system and collect its output This method is a coroutine wrapper around :meth:`create_process` which can be used to run a process to completion when no interactivity is needed. All of the arguments to :meth:`create_process` can be passed in to provide input or redirect stdin, stdout, and stderr, but this method waits until the process exits and returns an :class:`SSHCompletedProcess` object with the exit status or signal information and the output to stdout and stderr (if not redirected). If the check argument is set to ``True``, a non-zero exit status from the remote process will trigger the :exc:`ProcessError` exception to be raised. In addition to the argument below, all arguments to :meth:`create_process` are supported and have the same meaning. :param bool check: (optional) Whether or not to raise :exc:`ProcessError` when a non-zero exit status is returned :returns: :class:`SSHCompletedProcess` :raises: | :exc:`ChannelOpenError` if the session can't be opened | :exc:`ProcessError` if checking non-zero exit status """ process = yield from self.create_process(*args, **kwargs) return (yield from process.wait(check)) @asyncio.coroutine def create_connection(self, session_factory, remote_host, remote_port, orig_host='', orig_port=0, *, encoding=None, window=_DEFAULT_WINDOW, max_pktsize=_DEFAULT_MAX_PKTSIZE): """Create an SSH TCP direct connection This method is a coroutine which can be called to request that the server open a new outbound TCP connection to the specified destination host and port. If the connection is successfully opened, a new SSH channel will be opened with data being handled by a :class:`SSHTCPSession` object created by ``session_factory``. Optional arguments include the host and port of the original client opening the connection when performing TCP port forwarding. By default, this class expects data to be sent and received as raw bytes. However, an optional encoding argument can be passed in to select the encoding to use, allowing the application send and receive string data. 
Other optional arguments include the SSH receive window size and max packet size which default to 2 MB and 32 KB, respectively. :param callable session_factory: A callable which returns an :class:`SSHClientSession` object that will be created to handle activity on this session :param str remote_host: The remote hostname or address to connect to :param int remote_port: The remote port number to connect to :param str orig_host: (optional) The hostname or address of the client requesting the connection :param int orig_port: (optional) The port number of the client requesting the connection :param str encoding: (optional) The Unicode encoding to use for data exchanged on the connection :param int window: (optional) The receive window size for this session :param int max_pktsize: (optional) The maximum packet size for this session :returns: an :class:`SSHTCPChannel` and :class:`SSHTCPSession` :raises: :exc:`ChannelOpenError` if the connection can't be opened """ chan = self.create_tcp_channel(encoding, window, max_pktsize) session = yield from chan.connect(session_factory, remote_host, remote_port, orig_host, orig_port) return chan, session @asyncio.coroutine def open_connection(self, *args, **kwargs): """Open an SSH TCP direct connection This method is a coroutine wrapper around :meth:`create_connection` designed to provide a "high-level" stream interface for creating an SSH TCP direct connection. Instead of taking a ``session_factory`` argument for constructing an object which will handle activity on the session via callbacks, it returns :class:`SSHReader` and :class:`SSHWriter` objects which can be used to perform I/O on the connection. With the exception of ``session_factory``, all of the arguments to :meth:`create_connection` are supported and have the same meaning here. :returns: an :class:`SSHReader` and :class:`SSHWriter` :raises: :exc:`ChannelOpenError` if the connection can't be opened """ chan, session = yield from self.create_connection(SSHTCPStreamSession, *args, **kwargs) return SSHReader(session, chan), SSHWriter(session, chan) @asyncio.coroutine def create_server(self, session_factory, listen_host, listen_port, *, encoding=None, window=_DEFAULT_WINDOW, max_pktsize=_DEFAULT_MAX_PKTSIZE): """Create a remote SSH TCP listener This method is a coroutine which can be called to request that the server listen on the specified remote address and port for incoming TCP connections. If the request is successful, the return value is an :class:`SSHListener` object which can be used later to shut down the listener. If the request fails, ``None`` is returned. 
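A rough sketch of requesting such a listener, using an illustrative
:class:`SSHTCPSession` subclass and an established connection ``conn``::

    import asyncio, asyncssh

    class LoggingTCPSession(asyncssh.SSHTCPSession):
        """Illustrative session which just prints received data"""

        def data_received(self, data, datatype):
            print('Received:', data)

    @asyncio.coroutine
    def listen_remotely(conn):
        def session_factory(orig_host, orig_port):
            return LoggingTCPSession()

        # Ask the server to listen on port 8888 on all of its interfaces
        listener = yield from conn.create_server(session_factory, '', 8888)

        if listener:
            yield from listener.wait_closed()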
:param session_factory: A callable or coroutine which takes arguments of the original host and port of the client and decides whether to accept the connection or not, either returning an :class:`SSHTCPSession` object used to handle activity on that connection or raising :exc:`ChannelOpenError` to indicate that the connection should not be accepted :param str listen_host: The hostname or address on the remote host to listen on :param int listen_port: The port number on the remote host to listen on :param str encoding: (optional) The Unicode encoding to use for data exchanged on the connection :param int window: (optional) The receive window size for this session :param int max_pktsize: (optional) The maximum packet size for this session :type session_factory: callable or coroutine :returns: :class:`SSHListener` or ``None`` if the listener can't be opened """ listen_host = listen_host.lower() pkttype, packet = yield from self._make_global_request( b'tcpip-forward', String(listen_host), UInt32(listen_port)) if pkttype == MSG_REQUEST_SUCCESS: if listen_port == 0: listen_port = packet.get_uint32() dynamic = True else: # OpenSSH 6.8 introduced a bug which causes the reply # to contain an extra uint32 value of 0 when non-dynamic # ports are requested, causing the check_end() call below # to fail. This check works around this problem. if len(packet.get_remaining_payload()) == 4: # pragma: no cover packet.get_uint32() dynamic = False packet.check_end() listener = SSHTCPClientListener(self, self._loop, session_factory, listen_host, listen_port, encoding, window, max_pktsize) if dynamic: self._dynamic_remote_listeners[listen_host] = listener self._remote_listeners[listen_host, listen_port] = listener return listener else: packet.check_end() return None @asyncio.coroutine def start_server(self, handler_factory, *args, **kwargs): """Start a remote SSH TCP listener This method is a coroutine wrapper around :meth:`create_server` designed to provide a "high-level" stream interface for creating remote SSH TCP listeners. Instead of taking a ``session_factory`` argument for constructing an object which will handle activity on the session via callbacks, it takes a ``handler_factory`` which returns a callable or coroutine that will be passed :class:`SSHReader` and :class:`SSHWriter` objects which can be used to perform I/O on each new connection which arrives. Like :meth:`create_server`, ``handler_factory`` can also raise :exc:`ChannelOpenError` if the connection should not be accepted. With the exception of ``handler_factory`` replacing ``session_factory``, all of the arguments to :meth:`create_server` are supported and have the same meaning here. 
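For illustration, a simple echo service over forwarded connections might
look something like this (the port number and helper names are
placeholders)::

    import asyncio

    @asyncio.coroutine
    def echo_handler(reader, writer):
        # Echo everything received on the forwarded connection
        while not reader.at_eof():
            data = yield from reader.read(8192)
            if data:
                writer.write(data)

        writer.close()

    @asyncio.coroutine
    def start_echo_server(conn):
        listener = yield from conn.start_server(
            lambda orig_host, orig_port: echo_handler, '', 8888)

        if listener:
            yield from listener.wait_closed()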
:param handler_factory: A callable or coroutine which takes arguments of the original host and port of the client and decides whether to accept the connection or not, either returning a callback or coroutine used to handle activity on that connection or raising :exc:`ChannelOpenError` to indicate that the connection should not be accepted :type handler_factory: callable or coroutine :returns: :class:`SSHListener` or ``None`` if the listener can't be opened """ def session_factory(orig_host, orig_port): """Return a TCP stream session handler""" return SSHTCPStreamSession(handler_factory(orig_host, orig_port)) return (yield from self.create_server(session_factory, *args, **kwargs)) @asyncio.coroutine def create_unix_connection(self, session_factory, remote_path, *, encoding=None, window=_DEFAULT_WINDOW, max_pktsize=_DEFAULT_MAX_PKTSIZE): """Create an SSH UNIX domain socket direct connection This method is a coroutine which can be called to request that the server open a new outbound UNIX domain socket connection to the specified destination path. If the connection is successfully opened, a new SSH channel will be opened with data being handled by a :class:`SSHUNIXSession` object created by ``session_factory``. By default, this class expects data to be sent and received as raw bytes. However, an optional encoding argument can be passed in to select the encoding to use, allowing the application send and receive string data. Other optional arguments include the SSH receive window size and max packet size which default to 2 MB and 32 KB, respectively. :param callable session_factory: A callable which returns an :class:`SSHClientSession` object that will be created to handle activity on this session :param str remote_path: The remote path to connect to :param str encoding: (optional) The Unicode encoding to use for data exchanged on the connection :param int window: (optional) The receive window size for this session :param int max_pktsize: (optional) The maximum packet size for this session :returns: an :class:`SSHUNIXChannel` and :class:`SSHUNIXSession` :raises: :exc:`ChannelOpenError` if the connection can't be opened """ chan = self.create_unix_channel(encoding, window, max_pktsize) session = yield from chan.connect(session_factory, remote_path) return chan, session @asyncio.coroutine def open_unix_connection(self, *args, **kwargs): """Open an SSH UNIX domain socket direct connection This method is a coroutine wrapper around :meth:`create_unix_connection` designed to provide a "high-level" stream interface for creating an SSH UNIX domain socket direct connection. Instead of taking a ``session_factory`` argument for constructing an object which will handle activity on the session via callbacks, it returns :class:`SSHReader` and :class:`SSHWriter` objects which can be used to perform I/O on the connection. With the exception of ``session_factory``, all of the arguments to :meth:`create_unix_connection` are supported and have the same meaning here. 
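A brief sketch, assuming an established connection ``conn`` and a
placeholder socket path on the remote host::

    import asyncio

    @asyncio.coroutine
    def query_socket(conn):
        reader, writer = \
            yield from conn.open_unix_connection('/var/run/example.sock')

        # With no encoding set, raw bytes are sent and received
        writer.write(b'status\n')
        writer.write_eof()

        response = yield from reader.read()
        writer.close()
        return response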
:returns: an :class:`SSHReader` and :class:`SSHWriter` :raises: :exc:`ChannelOpenError` if the connection can't be opened """ chan, session = \ yield from self.create_unix_connection(SSHUNIXStreamSession, *args, **kwargs) return SSHReader(session, chan), SSHWriter(session, chan) @asyncio.coroutine def create_unix_server(self, session_factory, listen_path, *, encoding=None, window=_DEFAULT_WINDOW, max_pktsize=_DEFAULT_MAX_PKTSIZE): """Create a remote SSH UNIX domain socket listener This method is a coroutine which can be called to request that the server listen on the specified remote path for incoming UNIX domain socket connections. If the request is successful, the return value is an :class:`SSHListener` object which can be used later to shut down the listener. If the request fails, ``None`` is returned. :param session_factory: A callable or coroutine which decides whether to accept the connection or not, either returning an :class:`SSHUNIXSession` object used to handle activity on that connection or raising :exc:`ChannelOpenError` to indicate that the connection should not be accepted :param str listen_path: The path on the remote host to listen on :param str encoding: (optional) The Unicode encoding to use for data exchanged on the connection :param int window: (optional) The receive window size for this session :param int max_pktsize: (optional) The maximum packet size for this session :type session_factory: callable or coroutine :returns: :class:`SSHListener` or ``None`` if the listener can't be opened """ pkttype, packet = yield from self._make_global_request( b'streamlocal-forward@openssh.com', String(listen_path)) packet.check_end() if pkttype == MSG_REQUEST_SUCCESS: listener = SSHUNIXClientListener(self, self._loop, session_factory, listen_path, encoding, window, max_pktsize) self._remote_listeners[listen_path] = listener return listener else: return None @asyncio.coroutine def start_unix_server(self, handler_factory, *args, **kwargs): """Start a remote SSH UNIX domain socket listener This method is a coroutine wrapper around :meth:`create_unix_server` designed to provide a "high-level" stream interface for creating remote SSH UNIX domain socket listeners. Instead of taking a ``session_factory`` argument for constructing an object which will handle activity on the session via callbacks, it takes a ``handler_factory`` which returns a callable or coroutine that will be passed :class:`SSHReader` and :class:`SSHWriter` objects which can be used to perform I/O on each new connection which arrives. Like :meth:`create_unix_server`, ``handler_factory`` can also raise :exc:`ChannelOpenError` if the connection should not be accepted. With the exception of ``handler_factory`` replacing ``session_factory``, all of the arguments to :meth:`create_unix_server` are supported and have the same meaning here.
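For example, a minimal sketch (the listen path is arbitrary and ``conn`` is assumed to be an established :class:`SSHClientConnection`)::

    @asyncio.coroutine
    def handle_connection(reader, writer):
        data = yield from reader.read()         # read until EOF
        writer.write(data)
        writer.close()

    def handler_factory():
        return handle_connection                # accept every connection

    listener = yield from conn.start_unix_server(handler_factory,
                                                 '/tmp/example.sock')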
:param handler_factory: A callable or coroutine which decides whether to accept the connection or not, either returning a callback or coroutine used to handle activity on that connection or raising :exc:`ChannelOpenError` to indicate that the connection should not be accepted :type handler_factory: callable or coroutine :returns: :class:`SSHListener` or ``None`` if the listener can't be opened """ def session_factory(): """Return a UNIX domain socket stream session handler""" return SSHUNIXStreamSession(handler_factory()) return (yield from self.create_unix_server(session_factory, *args, **kwargs)) @asyncio.coroutine def create_ssh_connection(self, client_factory, host, port=_DEFAULT_PORT, *args, **kwargs): """Create a tunneled SSH client connection This method is a coroutine which can be called to open an SSH client connection to the requested host and port tunneled inside this already established connection. It takes all the same arguments as :func:`create_connection` but requests that the upstream SSH server open the connection rather than connecting directly. """ return (yield from create_connection(client_factory, host, port, *args, tunnel=self, **kwargs)) @asyncio.coroutine def connect_ssh(self, host, port=_DEFAULT_PORT, *args, **kwargs): """Make a tunneled SSH client connection This method is a coroutine which can be called to open an SSH client connection to the requested host and port tunneled inside this already established connection. It takes all the same arguments as :func:`connect` but requests that the upstream SSH server open the connection rather than connecting directly. """ return (yield from connect(host, port, *args, tunnel=self, **kwargs)) @asyncio.coroutine def forward_remote_port(self, listen_host, listen_port, dest_host, dest_port): """Set up remote port forwarding This method is a coroutine which attempts to set up port forwarding from a remote listening port to a local host and port via the SSH connection. If the request is successful, the return value is an :class:`SSHListener` object which can be used later to shut down the port forwarding. If the request fails, ``None`` is returned. :param str listen_host: The hostname or address on the remote host to listen on :param int listen_port: The port number on the remote host to listen on :param str dest_host: The hostname or address to forward connections to :param int dest_port: The port number to forward connections to :returns: :class:`SSHListener` or ``None`` if the listener can't be opened """ def session_factory(orig_host, orig_port): """Return an SSHTCPSession used to do remote port forwarding""" # pylint: disable=unused-argument return self.forward_connection(dest_host, dest_port) return (yield from self.create_server(session_factory, listen_host, listen_port)) @asyncio.coroutine def forward_remote_path(self, listen_path, dest_path): """Set up remote UNIX domain socket forwarding This method is a coroutine which attempts to set up UNIX domain socket forwarding from a remote listening path to a local path via the SSH connection. If the request is successful, the return value is an :class:`SSHListener` object which can be used later to shut down the port forwarding. If the request fails, ``None`` is returned.
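For example, a rough sketch (both paths are illustrative only and ``conn`` is assumed to be an established :class:`SSHClientConnection`)::

    listener = yield from conn.forward_remote_path('/tmp/remote.sock',
                                                   '/tmp/local.sock')
    if listener is None:
        print('Remote path forwarding request was rejected')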
:param str listen_path: The path on the remote host to listen on :param str dest_path: The path on the local host to forward connections to :returns: :class:`SSHListener` or ``None`` if the listener can't be opened """ def session_factory(): """Return an SSHUNIXSession used to do remote path forwarding""" return self.forward_unix_connection(dest_path) return (yield from self.create_unix_server(session_factory, listen_path)) @async_context_manager def start_sftp_client(self, path_encoding='utf-8', path_errors='strict'): """Start an SFTP client This method is a coroutine which attempts to start a secure file transfer session. If it succeeds, it returns an :class:`SFTPClient` object which can be used to copy and access files on the remote host. An optional Unicode encoding can be specified for sending and receiving pathnames, defaulting to UTF-8 with strict error checking. If an encoding of ``None`` is specified, pathnames will be left as bytes rather than being converted to & from strings. :param str path_encoding: The Unicode encoding to apply when sending and receiving remote pathnames :param str path_errors: The error handling strategy to apply on encode/decode errors :returns: :class:`SFTPClient` :raises: :exc:`SFTPError` if the session can't be opened """ writer, reader, _ = yield from self.open_session(subsystem='sftp', encoding=None) return (yield from start_sftp_client(self, self._loop, reader, writer, path_encoding, path_errors)) class SSHServerConnection(SSHConnection): """SSH server connection This class represents an SSH server connection. During authentication, :meth:`send_auth_banner` can be called to send an authentication banner to the client. Once authenticated, :class:`SSHServer` objects wishing to create session objects with non-default channel properties can call :meth:`create_server_channel` from their :meth:`session_requested() ` method and return a tuple of the :class:`SSHServerChannel` object returned from that and either an :class:`SSHServerSession` object or a coroutine which returns an :class:`SSHServerSession`. Similarly, :class:`SSHServer` objects wishing to create TCP connection objects with non-default channel properties can call :meth:`create_tcp_channel` from their :meth:`connection_requested() ` method and return a tuple of the :class:`SSHTCPChannel` object returned from that and either an :class:`SSHTCPSession` object or a coroutine which returns an :class:`SSHTCPSession`. :class:`SSHServer` objects wishing to create UNIX domain socket connection objects with non-default channel properties can call :meth:`create_unix_channel` from the :meth:`unix_connection_requested() ` method and return a tuple of the :class:`SSHUNIXChannel` object returned from that and either an :class:`SSHUNIXSession` object or a coroutine which returns an :class:`SSHUNIXSession`. 
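As a sketch of the first of these patterns, a server might disable Unicode encoding on its sessions like this (``MySession`` is a hypothetical :class:`SSHServerSession` implementation)::

    class MyServer(asyncssh.SSHServer):
        def connection_made(self, conn):
            self._conn = conn                   # remember the connection

        def session_requested(self):
            chan = self._conn.create_server_channel(encoding=None)
            return chan, MySession()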
""" def __init__(self, server_factory, loop, server_version, kex_algs, encryption_algs, mac_algs, compression_algs, signature_algs, rekey_bytes, rekey_seconds, server_host_keys, authorized_client_keys, x509_trusted_certs, x509_trusted_cert_paths, x509_purposes, gss_host, allow_pty, line_editor, line_history, x11_forwarding, x11_auth_path, agent_forwarding, process_factory, session_factory, session_encoding, sftp_factory, allow_scp, window, max_pktsize, login_timeout): super().__init__(server_factory, loop, server_version, kex_algs, encryption_algs, mac_algs, compression_algs, signature_algs, rekey_bytes, rekey_seconds, server=True) self._server_host_keys = server_host_keys self._server_host_key_algs = server_host_keys.keys() self._client_keys = authorized_client_keys self._x509_trusted_certs = x509_trusted_certs self._x509_trusted_cert_paths = x509_trusted_cert_paths self._x509_purposes = x509_purposes self._allow_pty = allow_pty self._line_editor = line_editor self._line_history = line_history self._x11_forwarding = x11_forwarding self._x11_auth_path = x11_auth_path self._agent_forwarding = agent_forwarding self._process_factory = process_factory self._session_factory = session_factory self._session_encoding = session_encoding self._sftp_factory = sftp_factory self._allow_scp = allow_scp self._window = window self._max_pktsize = max_pktsize if gss_host: try: self._gss = GSSServer(gss_host) self._gss_mic_auth = True except GSSError: pass if login_timeout: self._login_timer = loop.call_later(login_timeout, self._login_timer_callback) else: self._login_timer = None self._server_host_key = None self._key_options = {} self._cert_options = None self._kbdint_password_auth = False self._agent_listener = None def _cleanup(self, exc): """Clean up this server connection""" if self._agent_listener: self._agent_listener.close() self._agent_listener = None self._cancel_login_timer() super()._cleanup(exc) def _cancel_login_timer(self): """Cancel the login timer""" if self._login_timer: self._login_timer.cancel() self._login_timer = None def _login_timer_callback(self): """Close the connection if authentication hasn't completed yet""" self._login_timer = None self.connection_lost(DisconnectError(DISC_CONNECTION_LOST, 'Login timeout expired')) def _connection_made(self): """Handle the opening of a new connection""" pass def _choose_server_host_key(self, peer_host_key_algs): """Choose the server host key to use Given a list of host key algorithms supported by the client, select the first compatible server host key we have and return whether or not we were able to find a match. """ for alg in peer_host_key_algs: keypair = self._server_host_keys.get(alg) if keypair: if alg != keypair.algorithm: keypair.set_sig_algorithm(alg) self._server_host_key = keypair return True return False def get_server_host_key(self): """Return the chosen server host key This method returns a keypair object containing the chosen server host key and a corresponding public key or certificate. 
""" return self._server_host_key def gss_kex_auth_supported(self): """Return whether GSS key exchange authentication is supported""" return self._gss_kex_auth and self._gss.complete def gss_mic_auth_supported(self): """Return whether GSS MIC authentication is supported""" return self._gss_mic_auth @asyncio.coroutine def validate_gss_principal(self, username, user_principal, host_principal): """Validate the GSS principal name for the specified user Return whether the user principal acquired during GSS authentication is valid for the specified user. """ result = self._owner.validate_gss_principal(username, user_principal, host_principal) if asyncio.iscoroutine(result): result = yield from result return result @asyncio.coroutine def _validate_openssh_certificate(self, username, cert): """Validate an OpenSSH client certificate for the specified user""" options = None if self._client_keys: options = self._client_keys.validate(cert.signing_key, self._peer_addr, cert.principals, ca=True) if options is None: result = self._owner.validate_ca_key(username, cert.signing_key) if asyncio.iscoroutine(result): result = yield from result if not result: return None options = {} self._key_options = options if self.get_key_option('principals'): username = None try: cert.validate(CERT_TYPE_USER, username) except ValueError: return None allowed_addresses = self.get_certificate_option('source-address') if allowed_addresses: ip = ip_address(self._peer_addr) if not any(ip in network for network in allowed_addresses): return None self._cert_options = cert.options return cert.key @asyncio.coroutine def _validate_x509_certificate_chain(self, username, cert): """Validate an X.509 client certificate for the specified user""" if not self._client_keys: return None options, trusted_cert = \ self._client_keys.validate_x509(cert, self._peer_addr) if options is None: return None self._key_options = options if self.get_key_option('principals'): username = None if trusted_cert: trusted_certs = self._x509_trusted_certs + [trusted_cert] else: trusted_certs = self._x509_trusted_certs try: cert.validate_chain(trusted_certs, self._x509_trusted_cert_paths, None, self._x509_purposes, user_principal=username) except ValueError: return None return cert.key @asyncio.coroutine def _validate_client_certificate(self, username, key_data): """Validate a client certificate for the specified user""" try: cert = decode_ssh_certificate(key_data) except KeyImportError: return None if cert.is_x509_chain: return self._validate_x509_certificate_chain(username, cert) else: return self._validate_openssh_certificate(username, cert) @asyncio.coroutine def _validate_client_public_key(self, username, key_data): """Validate a client public key for the specified user""" try: key = decode_ssh_public_key(key_data) except KeyImportError: return None options = None if self._client_keys: options = self._client_keys.validate(key, self._peer_addr) if options is None: result = self._owner.validate_public_key(username, key) if asyncio.iscoroutine(result): result = yield from result if not result: return None options = {} self._key_options = options return key def public_key_auth_supported(self): """Return whether or not public key authentication is supported""" return (bool(self._client_keys) or self._owner.public_key_auth_supported()) @asyncio.coroutine def validate_public_key(self, username, key_data, msg, signature): """Validate the public key or certificate for the specified user This method validates that the public key or certificate provided is allowed for 
the specified user. If msg and signature are provided, the key is used to also validate the message signature. It returns ``True`` when the key is allowed and the signature (if present) is valid. Otherwise, it returns ``False``. """ key = ((yield from self._validate_client_certificate(username, key_data)) or (yield from self._validate_client_public_key(username, key_data))) if key is None: return False elif msg: return key.verify(String(self._session_id) + msg, signature) else: return True def password_auth_supported(self): """Return whether or not password authentication is supported""" return self._owner.password_auth_supported() @asyncio.coroutine def validate_password(self, username, password): """Return whether password is valid for this user""" result = self._owner.validate_password(username, password) if asyncio.iscoroutine(result): result = yield from result return result @asyncio.coroutine def change_password(self, username, old_password, new_password): """Handle a password change request for a user""" result = self._owner.change_password(username, old_password, new_password) if asyncio.iscoroutine(result): result = yield from result return result def kbdint_auth_supported(self): """Return whether or not keyboard-interactive authentication is supported""" result = self._owner.kbdint_auth_supported() if result is True: return True elif (result is NotImplemented and self._owner.password_auth_supported()): self._kbdint_password_auth = True return True else: return False @asyncio.coroutine def get_kbdint_challenge(self, username, lang, submethods): """Return a keyboard-interactive auth challenge""" if self._kbdint_password_auth: result = ('', '', DEFAULT_LANG, (('Password:', False),)) else: result = self._owner.get_kbdint_challenge(username, lang, submethods) if asyncio.iscoroutine(result): result = yield from result return result @asyncio.coroutine def validate_kbdint_response(self, username, responses): """Return whether the keyboard-interactive response is valid for this user""" if self._kbdint_password_auth: if len(responses) != 1: return False try: result = self._owner.validate_password(username, responses[0]) if asyncio.iscoroutine(result): result = yield from result except PasswordChangeRequired: # Don't support password change requests for now in # keyboard-interactive auth result = False else: result = self._owner.validate_kbdint_response(username, responses) if asyncio.iscoroutine(result): result = yield from result return result def _process_session_open(self, packet): """Process an incoming session open request""" packet.check_end() if self._process_factory or self._session_factory or self._sftp_factory: chan = self.create_server_channel(self._session_encoding, self._window, self._max_pktsize) if self._process_factory: session = SSHServerProcess(self._process_factory, self._sftp_factory, self._allow_scp) else: session = SSHServerStreamSession(self._session_factory, self._sftp_factory, self._allow_scp) else: result = self._owner.session_requested() if not result: raise ChannelOpenError(OPEN_CONNECT_FAILED, 'Session refused') if isinstance(result, tuple): chan, result = result else: chan = self.create_server_channel(self._session_encoding, self._window, self._max_pktsize) if callable(result): session = SSHServerStreamSession(result, None, False) else: session = result return chan, session def _process_direct_tcpip_open(self, packet): """Process an incoming direct TCP/IP open request""" dest_host = packet.get_string() dest_port = packet.get_uint32() orig_host = 
packet.get_string() orig_port = packet.get_uint32() packet.check_end() try: dest_host = dest_host.decode('utf-8') orig_host = orig_host.decode('utf-8') except UnicodeDecodeError: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Invalid direct ' 'TCP/IP channel open request') from None if not self.check_key_permission('port-forwarding') or \ not self.check_certificate_permission('port-forwarding'): raise ChannelOpenError(OPEN_ADMINISTRATIVELY_PROHIBITED, 'Port forwarding not permitted') permitted_opens = self.get_key_option('permitopen') if permitted_opens and \ (dest_host, dest_port) not in permitted_opens and \ (dest_host, None) not in permitted_opens: raise ChannelOpenError(OPEN_ADMINISTRATIVELY_PROHIBITED, 'Port forwarding not permitted to %s ' 'port %s' % (dest_host, dest_port)) result = self._owner.connection_requested(dest_host, dest_port, orig_host, orig_port) if not result: raise ChannelOpenError(OPEN_CONNECT_FAILED, 'Connection refused') if result is True: result = self.forward_connection(dest_host, dest_port) if isinstance(result, tuple): chan, result = result else: chan = self.create_tcp_channel() if callable(result): session = SSHTCPStreamSession(result) else: session = result chan.set_inbound_peer_names(dest_host, dest_port, orig_host, orig_port) return chan, session def _process_tcpip_forward_global_request(self, packet): """Process an incoming TCP/IP port forwarding request""" listen_host = packet.get_string() listen_port = packet.get_uint32() packet.check_end() try: listen_host = listen_host.decode('utf-8').lower() except UnicodeDecodeError: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Invalid TCP/IP port ' 'forward request') from None if not self.check_key_permission('port-forwarding') or \ not self.check_certificate_permission('port-forwarding'): self._report_global_response(False) return result = self._owner.server_requested(listen_host, listen_port) if not result: self._report_global_response(False) return if result is True: result = self.forward_local_port(listen_host, listen_port, listen_host, listen_port) self.create_task(self._finish_port_forward(result, listen_host, listen_port)) @asyncio.coroutine def _finish_port_forward(self, listener, listen_host, listen_port): """Finish processing a TCP/IP port forwarding request""" if asyncio.iscoroutine(listener): try: listener = yield from listener except OSError: listener = None if listener: if listen_port == 0: listen_port = listener.get_port() result = UInt32(listen_port) else: result = True self._local_listeners[listen_host, listen_port] = listener self._report_global_response(result) else: self._report_global_response(False) def _process_cancel_tcpip_forward_global_request(self, packet): """Process a request to cancel TCP/IP port forwarding""" listen_host = packet.get_string() listen_port = packet.get_uint32() packet.check_end() try: listen_host = listen_host.decode('utf-8').lower() except UnicodeDecodeError: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Invalid TCP/IP cancel ' 'forward request') from None try: listener = self._local_listeners.pop((listen_host, listen_port)) except KeyError: raise DisconnectError(DISC_PROTOCOL_ERROR, 'TCP/IP listener ' 'not found') from None listener.close() self._report_global_response(True) def _process_direct_streamlocal_at_openssh_dot_com_open(self, packet): """Process an incoming direct UNIX domain socket open request""" dest_path = packet.get_string() # OpenSSH appears to have a bug which sends this extra data _ = packet.get_string() # originator _ = packet.get_uint32() # originator_port 
packet.check_end() try: dest_path = dest_path.decode('utf-8') except UnicodeDecodeError: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Invalid direct UNIX ' 'domain channel open request') from None if not self.check_key_permission('port-forwarding') or \ not self.check_certificate_permission('port-forwarding'): raise ChannelOpenError(OPEN_ADMINISTRATIVELY_PROHIBITED, 'Port forwarding not permitted') result = self._owner.unix_connection_requested(dest_path) if not result: raise ChannelOpenError(OPEN_CONNECT_FAILED, 'Connection refused') if result is True: result = self.forward_unix_connection(dest_path) if isinstance(result, tuple): chan, result = result else: chan = self.create_unix_channel() if callable(result): session = SSHUNIXStreamSession(result) else: session = result chan.set_inbound_peer_names(dest_path) return chan, session def _process_streamlocal_forward_at_openssh_dot_com_global_request(self, packet): """Process an incoming UNIX domain socket forwarding request""" listen_path = packet.get_string() packet.check_end() try: listen_path = listen_path.decode('utf-8') except UnicodeDecodeError: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Invalid UNIX domain ' 'socket forward request') from None if not self.check_key_permission('port-forwarding') or \ not self.check_certificate_permission('port-forwarding'): self._report_global_response(False) return result = self._owner.unix_server_requested(listen_path) if not result: self._report_global_response(False) return if result is True: result = self.forward_local_path(listen_path, listen_path) self.create_task(self._finish_path_forward(result, listen_path)) @asyncio.coroutine def _finish_path_forward(self, listener, listen_path): """Finish processing a UNIX domain socket forwarding request""" if asyncio.iscoroutine(listener): try: listener = yield from listener except OSError: listener = None if listener: self._local_listeners[listen_path] = listener self._report_global_response(True) else: self._report_global_response(False) def _process_cancel_streamlocal_forward_at_openssh_dot_com_global_request( self, packet): """Process a request to cancel UNIX domain socket forwarding""" listen_path = packet.get_string() packet.check_end() try: listen_path = listen_path.decode('utf-8') except UnicodeDecodeError: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Invalid UNIX domain ' 'cancel forward request') from None try: listener = self._local_listeners.pop(listen_path) except KeyError: raise DisconnectError(DISC_PROTOCOL_ERROR, 'UNIX domain listener ' 'not found') from None listener.close() self._report_global_response(True) @asyncio.coroutine def attach_x11_listener(self, chan, auth_proto, auth_data, screen): """Attach a channel to a remote X11 display""" if (not self._x11_forwarding or not self.check_key_permission('X11-forwarding') or not self.check_certificate_permission('X11-forwarding')): return None if not self._x11_listener: self._x11_listener = yield from create_x11_server_listener( self, self._loop, self._x11_auth_path, auth_proto, auth_data) if self._x11_listener: return self._x11_listener.attach(chan, screen) else: return None def detach_x11_listener(self, chan): """Detach a session from a remote X11 listener""" if self._x11_listener: if self._x11_listener.detach(chan): self._x11_listener = None def create_agent_listener(self): """Create a listener for forwarding ssh-agent connections""" if (not self._agent_forwarding or not self.check_key_permission('agent-forwarding') or not self.check_certificate_permission('agent-forwarding')): return 
False if not self._agent_listener: self._agent_listener = yield from create_agent_listener(self, self._loop) return bool(self._agent_listener) def get_agent_path(self): """Return the path of the ssh-agent listener, if one exists""" if self._agent_listener: return self._agent_listener.get_path() else: return None def send_auth_banner(self, msg, lang=DEFAULT_LANG): """Send an authentication banner to the client This method can be called to send an authentication banner to the client, displaying information while authentication is in progress. It is an error to call this method after the authentication is complete. :param str msg: The message to display :param str lang: The language the message is in :raises: :exc:`OSError` if authentication is already completed """ if self._auth_complete: raise OSError('Authentication already completed') self.send_packet(Byte(MSG_USERAUTH_BANNER), String(msg), String(lang)) def set_authorized_keys(self, authorized_keys): """Set the keys trusted for client public key authentication This method can be called to set the trusted user and CA keys for client public key authentication. It should generally be called from the :meth:`begin_auth ` method of :class:`SSHServer` to set the appropriate keys for the user attempting to authenticate. :param authorized_keys: The keys to trust for client public key authentication :type authorized_keys: *see* :ref:`SpecifyingAuthorizedKeys` """ if isinstance(authorized_keys, str): authorized_keys = read_authorized_keys(authorized_keys) self._client_keys = authorized_keys def get_key_option(self, option, default=None): """Return option from authorized_keys If a client key or certificate was presented during authentication, this method returns the value of the requested option in the corresponding authorized_keys entry if it was set. Otherwise, it returns the default value provided. The following standard options are supported: | command (string) | environment (dictionary of name/value pairs) | from (list of host patterns) | permitopen (list of host/port tuples) | principals (list of usernames) Non-standard options are also supported and will return the value ``True`` if the option is present without a value or return a list of strings containing the values associated with each occurrence of that option name. If the option is not present, the specified default value is returned. :param str option: The name of the option to look up. :param default: The default value to return if the option is not present. :returns: The value of the option in authorized_keys, if set """ return self._key_options.get(option, default) def check_key_permission(self, permission): """Check permissions in authorized_keys If a client key or certificate was presented during authentication, this method returns whether the specified permission is allowed by the corresponding authorized_keys entry. By default, all permissions are granted, but they can be revoked by specifying an option starting with 'no-' without a value. The following standard options are supported: | X11-forwarding | agent-forwarding | port-forwarding | pty | user-rc AsyncSSH internally enforces X11-forwarding, agent-forwarding, port-forwarding and pty permissions but ignores user-rc since it does not implement that feature. Non-standard permissions can also be checked, as long as the option follows the convention of starting with 'no-'. :param str permission: The name of the permission to check (without the 'no-'). :returns: A bool indicating if the permission is granted. 
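For example, a server callback holding a reference to this connection might combine this check with :meth:`get_key_option` when deciding how to handle a request (a rough sketch)::

    can_forward = (conn.check_key_permission('port-forwarding') and
                   conn.check_certificate_permission('port-forwarding'))
    forced_command = conn.get_key_option('command')   # None unless forced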
""" return not self._key_options.get('no-' + permission, False) def get_certificate_option(self, option, default=None): """Return option from user certificate If a user certificate was presented during authentication, this method returns the value of the requested option in the certificate if it was set. Otherwise, it returns the default value provided. The following options are supported: | force-command (string) | source-address (list of CIDR-style IP network addresses) :param str option: The name of the option to look up. :param default: The default value to return if the option is not present. :returns: The value of the option in the user certificate, if set """ if self._cert_options is not None: return self._cert_options.get(option, default) else: return default def check_certificate_permission(self, permission): """Check permissions in user certificate If a user certificate was presented during authentication, this method returns whether the specified permission was granted in the certificate. Otherwise, it acts as if all permissions are granted and returns ``True``. The following permissions are supported: | X11-forwarding | agent-forwarding | port-forwarding | pty | user-rc AsyncSSH internally enforces agent-forwarding, port-forwarding and pty permissions but ignores the other values since it does not implement those features. :param str permission: The name of the permission to check (without the 'permit-'). :returns: A bool indicating if the permission is granted. """ if self._cert_options is not None: return self._cert_options.get('permit-' + permission, False) else: return True def create_server_channel(self, encoding='utf-8', window=_DEFAULT_WINDOW, max_pktsize=_DEFAULT_MAX_PKTSIZE): """Create an SSH server channel for a new SSH session This method can be called by :meth:`session_requested() ` to create an :class:`SSHServerChannel` with the desired encoding, window, and max packet size for a newly created SSH server session. :param str encoding: (optional) The Unicode encoding to use for data exchanged on the session, defaulting to UTF-8 (ISO 10646) format. If ``None`` is passed in, the application can send and receive raw bytes. :param int window: (optional) The receive window size for this session :param int max_pktsize: (optional) The maximum packet size for this session :returns: :class:`SSHServerChannel` """ return SSHServerChannel(self, self._loop, self._allow_pty, self._line_editor, self._line_history, encoding, window, max_pktsize) @asyncio.coroutine def create_connection(self, session_factory, remote_host, remote_port, orig_host='', orig_port=0, *, encoding=None, window=_DEFAULT_WINDOW, max_pktsize=_DEFAULT_MAX_PKTSIZE): """Create an SSH TCP forwarded connection This method is a coroutine which can be called to notify the client about a new inbound TCP connection arriving on the specified remote host and port. If the connection is successfully opened, a new SSH channel will be opened with data being handled by a :class:`SSHTCPSession` object created by ``session_factory``. Optional arguments include the host and port of the original client opening the connection when performing TCP port forwarding. By default, this class expects data to be sent and received as raw bytes. However, an optional encoding argument can be passed in to select the encoding to use, allowing the application send and receive string data. Other optional arguments include the SSH receive window size and max packet size which default to 2 MB and 32 KB, respectively. 
:param callable session_factory: A callable which returns an :class:`SSHTCPSession` object that will be created to handle activity on this session :param str remote_host: The hostname or address the connection was received on :param int remote_port: The port number the connection was received on :param str orig_host: (optional) The hostname or address of the client requesting the connection :param int orig_port: (optional) The port number of the client requesting the connection :param str encoding: (optional) The Unicode encoding to use for data exchanged on the connection :param int window: (optional) The receive window size for this session :param int max_pktsize: (optional) The maximum packet size for this session :returns: an :class:`SSHTCPChannel` and :class:`SSHTCPSession` """ chan = self.create_tcp_channel(encoding, window, max_pktsize) session = yield from chan.accept(session_factory, remote_host, remote_port, orig_host, orig_port) return chan, session @asyncio.coroutine def open_connection(self, *args, **kwargs): """Open an SSH TCP forwarded connection This method is a coroutine wrapper around :meth:`create_connection` designed to provide a "high-level" stream interface for creating an SSH TCP forwarded connection. Instead of taking a ``session_factory`` argument for constructing an object which will handle activity on the session via callbacks, it returns :class:`SSHReader` and :class:`SSHWriter` objects which can be used to perform I/O on the connection. With the exception of ``session_factory``, all of the arguments to :meth:`create_connection` are supported and have the same meaning here. :returns: an :class:`SSHReader` and :class:`SSHWriter` """ chan, session = yield from self.create_connection(SSHTCPStreamSession, *args, **kwargs) return SSHReader(session, chan), SSHWriter(session, chan) @asyncio.coroutine def create_unix_connection(self, session_factory, remote_path, *, encoding=None, window=_DEFAULT_WINDOW, max_pktsize=_DEFAULT_MAX_PKTSIZE): """Create an SSH UNIX domain socket forwarded connection This method is a coroutine which can be called to notify the client about a new inbound UNIX domain socket connection arriving on the specified remote path. If the connection is successfully opened, a new SSH channel will be opened with data being handled by a :class:`SSHUNIXSession` object created by ``session_factory``. By default, this class expects data to be sent and received as raw bytes. However, an optional encoding argument can be passed in to select the encoding to use, allowing the application to send and receive string data. Other optional arguments include the SSH receive window size and max packet size which default to 2 MB and 32 KB, respectively.
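For example (the path is illustrative, ``MyUNIXSession`` is a hypothetical :class:`SSHUNIXSession` implementation, and ``conn`` is this :class:`SSHServerConnection`)::

    chan, session = yield from conn.create_unix_connection(
        MyUNIXSession, '/var/run/example.sock')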
:param callable session_factory: A callable which returns an :class:`SSHUNIXSession` object that will be created to handle activity on this session :param str remote_path: The path the connection was received on :param str encoding: (optional) The Unicode encoding to use for data exchanged on the connection :param int window: (optional) The receive window size for this session :param int max_pktsize: (optional) The maximum packet size for this session :returns: an :class:`SSHUNIXChannel` and :class:`SSHUNIXSession` """ chan = self.create_unix_channel(encoding, window, max_pktsize) session = yield from chan.accept(session_factory, remote_path) return chan, session @asyncio.coroutine def open_unix_connection(self, *args, **kwargs): """Open an SSH UNIX domain socket forwarded connection This method is a coroutine wrapper around :meth:`create_unix_connection` designed to provide a "high-level" stream interface for creating an SSH UNIX domain socket forwarded connection. Instead of taking a ``session_factory`` argument for constructing an object which will handle activity on the session via callbacks, it returns :class:`SSHReader` and :class:`SSHWriter` objects which can be used to perform I/O on the connection. With the exception of ``session_factory``, all of the arguments to :meth:`create_unix_connection` are supported and have the same meaning here. :returns: an :class:`SSHReader` and :class:`SSHWriter` """ chan, session = \ yield from self.create_unix_connection(SSHUNIXStreamSession, *args, **kwargs) return SSHReader(session, chan), SSHWriter(session, chan) @asyncio.coroutine def create_x11_connection(self, session_factory, orig_host='', orig_port=0, *, window=_DEFAULT_WINDOW, max_pktsize=_DEFAULT_MAX_PKTSIZE): """Create an SSH X11 forwarded connection""" chan = self.create_x11_channel(None, window, max_pktsize) session = yield from chan.open(session_factory, orig_host, orig_port) return chan, session @asyncio.coroutine def create_agent_connection(self, session_factory, *, encoding=None, window=_DEFAULT_WINDOW, max_pktsize=_DEFAULT_MAX_PKTSIZE): """Create a forwarded ssh-agent connection back to the client""" if not self._agent_listener: raise ChannelOpenError(OPEN_ADMINISTRATIVELY_PROHIBITED, 'Agent forwarding not permitted') chan = self.create_agent_channel(encoding, window, max_pktsize) session = yield from chan.open(session_factory) return chan, session @asyncio.coroutine def open_agent_connection(self): """Open a forwarded ssh-agent connection back to the client""" chan, session = \ yield from self.create_agent_connection(SSHUNIXStreamSession) return SSHReader(session, chan), SSHWriter(session, chan) @asyncio.coroutine def create_connection(client_factory, host, port=_DEFAULT_PORT, *, loop=None, tunnel=None, family=0, flags=0, local_addr=None, known_hosts=(), x509_trusted_certs=(), x509_trusted_cert_paths=(), x509_purposes='secureShellServer', username=None, password=None, client_keys=(), passphrase=None, gss_host=(), gss_delegate_creds=False, agent_path=(), agent_forwarding=False, client_version=(), kex_algs=(), encryption_algs=(), mac_algs=(), compression_algs=(), signature_algs=(), rekey_bytes=_DEFAULT_REKEY_BYTES, rekey_seconds=_DEFAULT_REKEY_SECONDS): """Create an SSH client connection This function is a coroutine which can be run to create an outbound SSH client connection to the specified host and port. When successful, the following steps occur: 1. The connection is established and an :class:`SSHClientConnection` object is created to represent it. 2.
The ``client_factory`` is called without arguments and should return an :class:`SSHClient` object. 3. The client object is tied to the connection and its :meth:`connection_made() ` method is called. 4. The SSH handshake and authentication process is initiated, calling methods on the client object if needed. 5. When authentication completes successfully, the client's :meth:`auth_completed() ` method is called. 6. The coroutine returns the ``(connection, client)`` pair. At this point, the connection is ready for sessions to be opened or port forwarding to be set up. If an error occurs, it will be raised as an exception and the partially open connection and client objects will be cleaned up. .. note:: Unlike :func:`socket.create_connection`, asyncio calls to create a connection do not support a ``timeout`` parameter. However, asyncio calls can be wrapped in a call to :func:`asyncio.wait_for` or :func:`asyncio.wait` which takes a timeout and provides equivalent functionality. :param callable client_factory: A callable which returns an :class:`SSHClient` object that will be tied to the connection :param str host: The hostname or address to connect to :param int port: (optional) The port number to connect to. If not specified, the default SSH port is used. :param loop: (optional) The event loop to use when creating the connection. If not specified, the default event loop is used. :param tunnel: (optional) An existing SSH client connection that this new connection should be tunneled over. If set, a direct TCP/IP tunnel will be opened over this connection to the requested host and port rather than connecting directly via TCP. :param family: (optional) The address family to use when creating the socket. By default, the address family is automatically selected based on the host. :param flags: (optional) The flags to pass to getaddrinfo() when looking up the host address :param local_addr: (optional) The host and port to bind the socket to before connecting :param known_hosts: (optional) The list of keys which will be used to validate the server host key presented during the SSH handshake. If this is not specified, the keys will be looked up in the file :file:`.ssh/known_hosts`. If this is explicitly set to ``None``, server host key validation will be disabled. :param x509_trusted_certs: (optional) A list of certificates which should be trusted for X.509 server certificate authentication. If no trusted certificates are specified, an attempt will be made to load them from the file :file:`.ssh/ca-bundle.crt`. If this argument is explicitly set to ``None``, X.509 server certificate authentication will not be performed. .. note:: X.509 certificates to trust can also be provided through a :ref:`known_hosts ` file if they are converted into OpenSSH format. This allows their trust to be limited to only specific host names. :param x509_trusted_cert_paths: (optional) A list of path names to "hash directories" containing certificates which should be trusted for X.509 server certificate authentication. Each certificate should be in a separate file with a name of the form *hash.number*, where *hash* is the OpenSSL hash value of the certificate subject name and *number* is an integer counting up from zero if multiple certificates have the same hash. If no paths are specified, an attempt will be made to use the directory :file:`.ssh/crt` as a certificate hash directory.
:param x509_purposes: (optional) A list of purposes allowed in the ExtendedKeyUsage of a certificate used for X.509 server certificate authentication, defaulting to 'secureShellServer'. If this argument is explicitly set to ``None``, the server certificate's ExtendedKeyUsage will not be checked. :param str username: (optional) Username to authenticate as on the server. If not specified, the currently logged in user on the local machine will be used. :param str password: (optional) The password to use for client password authentication or keyboard-interactive authentication which prompts for a password. If this is not specified, client password authentication will not be performed. :param client_keys: (optional) A list of keys which will be used to authenticate this client via public key authentication. If no client keys are specified, an attempt will be made to get them from an ssh-agent process. If that is not available, an attempt will be made to load them from the files :file:`.ssh/id_ed25519`, :file:`.ssh/id_ecdsa`, :file:`.ssh/id_rsa`, and :file:`.ssh/id_dsa` in the user's home directory, with optional certificates loaded from the files :file:`.ssh/id_ed25519-cert.pub`, :file:`.ssh/id_ecdsa-cert.pub`, :file:`.ssh/id_rsa-cert.pub`, and :file:`.ssh/id_dsa-cert.pub`. If this argument is explicitly set to ``None``, client public key authentication will not be performed. :param str passphrase: (optional) The passphrase to use to decrypt client keys when loading them, if they are encrypted. If this is not specified, only unencrypted client keys can be loaded. If the keys passed into client_keys are already loaded, this argument is ignored. :param str gss_host: (optional) The principal name to use for the host in GSS key exchange and authentication. If not specified, this value will be the same as the ``host`` argument. If this argument is explicitly set to ``None``, GSS key exchange and authentication will not be performed. :param bool gss_delegate_creds: (optional) Whether or not to forward GSS credentials to the server being accessed. By default, GSS credential delegation is disabled. :param agent_path: (optional) The path of a UNIX domain socket to use to contact an ssh-agent process which will perform the operations needed for client public key authentication, or the :class:`SSHServerConnection` to use to forward ssh-agent requests over. If this is not specified and the environment variable ``SSH_AUTH_SOCK`` is set, its value will be used as the path. If ``client_keys`` is specified or this argument is explicitly set to ``None``, an ssh-agent will not be used. :param bool agent_forwarding: (optional) Whether or not to allow forwarding of ssh-agent requests from processes running on the server. By default, ssh-agent forwarding requests from the server are not allowed. :param str client_version: (optional) An ASCII string to advertise to the SSH server as the version of this client, defaulting to ``AsyncSSH`` and its version number.
:param kex_algs: (optional) A list of allowed key exchange algorithms in the SSH handshake, taken from :ref:`key exchange algorithms ` :param encryption_algs: (optional) A list of encryption algorithms to use during the SSH handshake, taken from :ref:`encryption algorithms ` :param mac_algs: (optional) A list of MAC algorithms to use during the SSH handshake, taken from :ref:`MAC algorithms ` :param compression_algs: (optional) A list of compression algorithms to use during the SSH handshake, taken from :ref:`compression algorithms `, or ``None`` to disable compression :param signature_algs: (optional) A list of public key signature algorithms to use during the SSH handshake, taken from :ref:`signature algorithms ` :param int rekey_bytes: (optional) The number of bytes which can be sent before the SSH session key is renegotiated. This defaults to 1 GB. :param int rekey_seconds: (optional) The maximum time in seconds before the SSH session key is renegotiated. This defaults to 1 hour. :type tunnel: :class:`SSHClientConnection` :type family: ``socket.AF_UNSPEC``, ``socket.AF_INET``, or ``socket.AF_INET6`` :type flags: flags to pass to :meth:`getaddrinfo() ` :type local_addr: tuple of str and int :type known_hosts: *see* :ref:`SpecifyingKnownHosts` :type x509_trusted_certs: *see* :ref:`SpecifyingCertificates` :type x509_trusted_cert_paths: list of str :type x509_purposes: *see* :ref:`SpecifyingX509Purposes` :type client_keys: *see* :ref:`SpecifyingPrivateKeys` :type agent_path: str or :class:`SSHServerConnection` :type kex_algs: list of str :type encryption_algs: list of str :type mac_algs: list of str :type compression_algs: list of str :type signature_algs: list of str :returns: An :class:`SSHClientConnection` and :class:`SSHClient` """ def conn_factory(): """Return an SSH client connection handler""" return SSHClientConnection(client_factory, loop, client_version, kex_algs, encryption_algs, mac_algs, compression_algs, signature_algs, rekey_bytes, rekey_seconds, host, port, known_hosts, x509_trusted_certs, x509_trusted_cert_paths, x509_purposes, username, password, client_keys, gss_host, gss_delegate_creds, agent, agent_path, auth_waiter) if not client_factory: client_factory = SSHClient if not loop: loop = asyncio.get_event_loop() client_version = _validate_version(client_version) kex_algs, encryption_algs, mac_algs, compression_algs, signature_algs = \ _validate_algs(kex_algs, encryption_algs, mac_algs, compression_algs, signature_algs, x509_trusted_certs is not None) if x509_trusted_certs is (): try: x509_trusted_certs = load_certificates( os.path.join(os.path.expanduser('~'), '.ssh', 'ca-bundle.crt')) except OSError: pass elif x509_trusted_certs is not None: x509_trusted_certs = load_certificates(x509_trusted_certs) if x509_trusted_cert_paths is (): path = os.path.join(os.path.expanduser('~'), '.ssh', 'crt') if os.path.isdir(path): x509_trusted_cert_paths = [path] elif x509_trusted_cert_paths: for path in x509_trusted_cert_paths: if not os.path.isdir(path): raise ValueError('Path not a directory: ' + str(path)) if username is None: username = getpass.getuser() if gss_host is (): gss_host = host agent = None if agent_path is (): agent_path = os.environ.get('SSH_AUTH_SOCK', None) if client_keys: client_keys = load_keypairs(client_keys, passphrase) elif client_keys is (): if agent_path: agent = yield from connect_agent(agent_path, loop=loop) if agent: client_keys = yield from agent.get_keys() else: agent_path = None if not client_keys: client_keys = load_default_keypairs(passphrase) if not 
agent_forwarding: agent_path = None auth_waiter = asyncio.Future(loop=loop) # pylint: disable=broad-except try: if tunnel: _, conn = yield from tunnel.create_connection(conn_factory, host, port) else: _, conn = yield from loop.create_connection(conn_factory, host, port, family=family, flags=flags, local_addr=local_addr) except Exception: if agent: agent.close() raise yield from auth_waiter return conn, conn.get_owner() @asyncio.coroutine def create_server(server_factory, host=None, port=_DEFAULT_PORT, *, loop=None, family=0, flags=socket.AI_PASSIVE, backlog=100, reuse_address=None, server_host_keys=None, passphrase=None, authorized_client_keys=None, x509_trusted_certs=(), x509_trusted_cert_paths=(), x509_purposes='secureShellClient', gss_host=(), allow_pty=True, line_editor=True, line_history=_DEFAULT_LINE_HISTORY, x11_forwarding=False, x11_auth_path=None, agent_forwarding=True, process_factory=None, session_factory=None, session_encoding='utf-8', sftp_factory=None, allow_scp=False, window=_DEFAULT_WINDOW, max_pktsize=_DEFAULT_MAX_PKTSIZE, server_version=(), kex_algs=(), encryption_algs=(), mac_algs=(), compression_algs=(), signature_algs=(), rekey_bytes=_DEFAULT_REKEY_BYTES, rekey_seconds=_DEFAULT_REKEY_SECONDS, login_timeout=_DEFAULT_LOGIN_TIMEOUT): """Create an SSH server This function is a coroutine which can be run to create an SSH server bound to the specified host and port. The return value is an object derived from :class:`asyncio.AbstractServer` which can be used to later shut down the server. :param callable server_factory: A callable which returns an :class:`SSHServer` object that will be created for each new inbound connection :param str host: (optional) The hostname or address to listen on. If not specified, listeners are created for all addresses. :param int port: (optional) The port number to listen on. If not specified, the default SSH port is used. :param loop: (optional) The event loop to use when creating the server. If not specified, the default event loop is used. :param family: (optional) The address family to use when creating the server. By default, the address families are automatically selected based on the host. :param flags: (optional) The flags to pass to getaddrinfo() when looking up the host :param int backlog: (optional) The maximum number of queued connections allowed on listeners :param bool reuse_address: (optional) Whether or not to reuse a local socket in the TIME_WAIT state without waiting for its natural timeout to expire. If not specified, this will be automatically set to ``True`` on UNIX. :param server_host_keys: (optional) A list of private keys and optional certificates which can be used by the server as a host key. Either this argument or ``gss_host`` must be specified. If this is not specified, only GSS-based key exchange will be supported. :param str passphrase: (optional) The passphrase to use to decrypt server host keys when loading them, if they are encrypted. If this is not specified, only unencrypted server host keys can be loaded. If the keys passed into server_host_keys are already loaded, this argument is ignored. :param authorized_client_keys: (optional) A list of authorized user and CA public keys which should be trusted for certificate-based client public key authentication. :param x509_trusted_certs: (optional) A list of certificates which should be trusted for X.509 client certificate authentication. If this argument is explicitly set to ``None``, X.509 client certificate authentication will not be performed. ..
note:: X.509 certificates to trust can also be provided through an :ref:`authorized_keys ` file if they are converted into OpenSSH format. This allows their trust to be limited to only specific client IPs or user names and allows SSH functions to be restricted when these certificates are used. :param x509_trusted_cert_paths: (optional) A list of path names to "hash directories" containing certificates which should be trusted for X.509 client certificate authentication. Each certificate should be in a separate file with a name of the form *hash.number*, where *hash* is the OpenSSL hash value of the certificate subject name and *number* is an integer counting up from zero if multiple certificates have the same hash. :param x509_purposes: (optional) A list of purposes allowed in the ExtendedKeyUsage of a certificate used for X.509 client certificate authentication, defaulting to 'secureShellClient'. If this argument is explicitly set to ``None``, the client certificate's ExtendedKeyUsage will not be checked. :param str gss_host: (optional) The principal name to use for the host in GSS key exchange and authentication. If not specified, the value returned by :func:`socket.gethostname` will be used if it is a fully qualified name. Otherwise, the value used by :func:`socket.getfqdn` will be used. If this argument is explicitly set to ``None``, GSS key exchange and authentication will not be performed. :param bool allow_pty: (optional) Whether or not to allow allocation of a pseudo-tty in sessions, defaulting to ``True`` :param bool line_editor: (optional) Whether or not to enable input line editing on sessions which have a pseudo-tty allocated, defaulting to ``True`` :param int line_history: (optional) The number of lines of input line history to store in the line editor when it is enabled, defaulting to 1000 :param bool x11_forwarding: (optional) Whether or not to allow forwarding of X11 connections back to the client when the client supports it, defaulting to ``False`` :param str x11_auth_path: (optional) The path to the Xauthority file to write X11 authentication data to, defaulting to the value in the environment variable ``XAUTHORITY`` or the file ``.Xauthority`` in the user's home directory if that's not set :param bool agent_forwarding: (optional) Whether or not to allow forwarding of ssh-agent requests back to the client when the client supports it, defaulting to ``True`` :param callable process_factory: (optional) A callable or coroutine handler function which takes an AsyncSSH :class:`SSHServerProcess` argument that will be called each time a new shell, exec, or subsystem other than SFTP is requested by the client. If set, this takes precedence over the ``session_factory`` argument. :param callable session_factory: (optional) A callable or coroutine handler function which takes AsyncSSH stream objects for stdin, stdout, and stderr that will be called each time a new shell, exec, or subsystem other than SFTP is requested by the client. If not specified, sessions are rejected by default unless the :meth:`session_requested() ` method is overridden on the :class:`SSHServer` object returned by ``server_factory`` to make this decision. :param str session_encoding: (optional) The Unicode encoding to use for data exchanged on sessions on this server, defaulting to UTF-8 (ISO 10646) format. If ``None`` is passed in, the application can send and receive raw bytes.
:param callable sftp_factory: (optional) A callable which returns an :class:`SFTPServer` object that will be created each time an SFTP session is requested by the client, or ``True`` to use the base :class:`SFTPServer` class to handle SFTP requests. If not specified, SFTP sessions are rejected by default. :param bool allow_scp: (optional) Whether or not to allow incoming scp requests to be accepted. This option can only be used in conjunction with ``sftp_factory``. If not specified, scp requests will be passed as regular commands to the ``process_factory`` or ``session_factory``. :param int window: (optional) The receive window size for sessions on this server :param int max_pktsize: (optional) The maximum packet size for sessions on this server :param str server_version: (optional) An ASCII string to advertise to SSH clients as the version of this server, defaulting to ``AsyncSSH`` and its version number. :param kex_algs: (optional) A list of allowed key exchange algorithms in the SSH handshake, taken from :ref:`key exchange algorithms ` :param encryption_algs: (optional) A list of encryption algorithms to use during the SSH handshake, taken from :ref:`encryption algorithms ` :param mac_algs: (optional) A list of MAC algorithms to use during the SSH handshake, taken from :ref:`MAC algorithms ` :param compression_algs: (optional) A list of compression algorithms to use during the SSH handshake, taken from :ref:`compression algorithms `, or ``None`` to disable compression :param signature_algs: (optional) A list of public key signature algorithms to use during the SSH handshake, taken from :ref:`signature algorithms ` :param int rekey_bytes: (optional) The number of bytes which can be sent before the SSH session key is renegotiated, defaulting to 1 GB :param int rekey_seconds: (optional) The maximum time in seconds before the SSH session key is renegotiated, defaulting to 1 hour :param int login_timeout: (optional) The maximum time in seconds allowed for authentication to complete, defaulting to 2 minutes :type family: ``socket.AF_UNSPEC``, ``socket.AF_INET``, or ``socket.AF_INET6`` :type flags: flags to pass to :meth:`getaddrinfo() ` :type server_host_keys: *see* :ref:`SpecifyingPrivateKeys` :type authorized_client_keys: *see* :ref:`SpecifyingAuthorizedKeys` :type x509_trusted_certs: *see* :ref:`SpecifyingCertificates` :type x509_trusted_cert_paths: list of str :type x509_purposes: *see* :ref:`SpecifyingX509Purposes` :type kex_algs: list of str :type encryption_algs: list of str :type mac_algs: list of str :type compression_algs: list of str :type signature_algs: list of str :returns: :class:`asyncio.AbstractServer` """ def conn_factory(): """Return an SSH server connection handler""" return SSHServerConnection(server_factory, loop, server_version, kex_algs, encryption_algs, mac_algs, compression_algs, signature_algs, rekey_bytes, rekey_seconds, server_host_keys, authorized_client_keys, x509_trusted_certs, x509_trusted_cert_paths, x509_purposes, gss_host, allow_pty, line_editor, line_history, x11_forwarding, x11_auth_path, agent_forwarding, process_factory, session_factory, session_encoding, sftp_factory, allow_scp, window, max_pktsize, login_timeout) if not server_factory: server_factory = SSHServer if sftp_factory is True: sftp_factory = SFTPServer if not loop: loop = asyncio.get_event_loop() server_version = _validate_version(server_version) if gss_host is (): gss_host = socket.gethostname() if '.'
not in gss_host: gss_host = socket.getfqdn() kex_algs, encryption_algs, mac_algs, compression_algs, signature_algs = \ _validate_algs(kex_algs, encryption_algs, mac_algs, compression_algs, signature_algs, x509_trusted_certs is not None) server_keys = load_keypairs(server_host_keys, passphrase) if not server_keys and not gss_host: raise ValueError('No server host keys provided') server_host_keys = OrderedDict() for keypair in server_keys: for alg in keypair.host_key_algorithms: if alg in server_host_keys: raise ValueError('Multiple keys of type %s found' % alg.decode('ascii')) server_host_keys[alg] = keypair if isinstance(authorized_client_keys, str): authorized_client_keys = read_authorized_keys(authorized_client_keys) if x509_trusted_certs is not None: x509_trusted_certs = load_certificates(x509_trusted_certs) return (yield from loop.create_server(conn_factory, host, port, family=family, flags=flags, backlog=backlog, reuse_address=reuse_address)) @async_context_manager def connect(host, port=_DEFAULT_PORT, **kwargs): """Make an SSH client connection This function is a coroutine wrapper around :func:`create_connection` which can be used when a custom SSHClient instance is not needed. It takes all the same arguments as :func:`create_connection` except for ``client_factory`` and returns only the :class:`SSHClientConnection` object rather than a tuple of an :class:`SSHClientConnection` and :class:`SSHClient`. When using this call, the following restrictions apply: 1. No callbacks are called when the connection is successfully opened, when it is closed, or when authentication completes. 2. Any authentication information must be provided as arguments to this call, as any authentication callbacks will deny other authentication attempts. Also, authentication banner information will be ignored. 3. Any debug messages sent by the server will be ignored. """ conn, _ = yield from create_connection(None, host, port, **kwargs) return conn @asyncio.coroutine def listen(host=None, port=_DEFAULT_PORT, **kwargs): """Start an SSH server This function is a coroutine wrapper around :func:`create_server` which can be used when a custom SSHServer instance is not needed. It takes all the same arguments as :func:`create_server` except for ``server_factory``. When using this call, the following restrictions apply: 1. No callbacks are called when a new connection arrives, when a connection is closed, or when authentication completes. 2. Any authentication information must be provided as arguments to this call, as any authentication callbacks will deny other authentication attempts. Currently, this allows only public key authentication to be used, by passing in the ``authorized_client_keys`` argument. 3. Only handlers using the streams API are supported and the same handlers must be used for all clients. These handlers must be provided in the ``process_factory``, ``session_factory``, and ``sftp_factory`` arguments to this call. 4. Any debug messages sent by the client will be ignored. """ return (yield from create_server(None, host, port, **kwargs)) asyncssh-1.11.1/asyncssh/constants.py000066400000000000000000000167141320320510200176440ustar00rootroot00000000000000# Copyright (c) 2013-2015 by Ron Frederick . # All rights reserved. 
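# --- Editor's illustrative sketch (not part of the original distribution) ---
# The listen() and connect() convenience wrappers above take the same keyword
# arguments as create_server()/create_connection() but skip the custom factory
# classes.  The sketch below shows one way they might be used together; the
# port number 8022 and the 'ssh_host_key'/'authorized_keys' file names are
# assumptions invented for this example, and imports are kept inside the
# function so the sketch has no import-time side effects.

def _example_echo_server_and_client():
    """Hypothetical round trip using the streams-based session API"""

    import asyncio, asyncssh

    @asyncio.coroutine
    def handle_session(stdin, stdout, stderr):
        line = yield from stdin.readline()
        stdout.write('You said: %s' % line)
        stdout.channel.exit(0)

    @asyncio.coroutine
    def run():
        yield from asyncssh.listen('localhost', 8022,
                                   server_host_keys=['ssh_host_key'],
                                   authorized_client_keys='authorized_keys',
                                   session_factory=handle_session)

        conn = yield from asyncssh.connect('localhost', 8022,
                                           known_hosts=None)
        stdin, stdout, stderr = yield from conn.open_session()
        stdin.write('hello\n')
        print((yield from stdout.readline()))
        conn.close()

    asyncio.get_event_loop().run_until_complete(run())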
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH constants""" # pylint: disable=bad-whitespace # Default language for error messages DEFAULT_LANG = 'en-US' # SSH message codes MSG_DISCONNECT = 1 MSG_IGNORE = 2 MSG_UNIMPLEMENTED = 3 MSG_DEBUG = 4 MSG_SERVICE_REQUEST = 5 MSG_SERVICE_ACCEPT = 6 MSG_EXT_INFO = 7 MSG_KEXINIT = 20 MSG_NEWKEYS = 21 MSG_KEX_FIRST = 30 MSG_KEX_LAST = 49 MSG_USERAUTH_REQUEST = 50 MSG_USERAUTH_FAILURE = 51 MSG_USERAUTH_SUCCESS = 52 MSG_USERAUTH_BANNER = 53 MSG_USERAUTH_FIRST = 60 MSG_USERAUTH_LAST = 79 MSG_GLOBAL_REQUEST = 80 MSG_REQUEST_SUCCESS = 81 MSG_REQUEST_FAILURE = 82 MSG_CHANNEL_OPEN = 90 MSG_CHANNEL_OPEN_CONFIRMATION = 91 MSG_CHANNEL_OPEN_FAILURE = 92 MSG_CHANNEL_WINDOW_ADJUST = 93 MSG_CHANNEL_DATA = 94 MSG_CHANNEL_EXTENDED_DATA = 95 MSG_CHANNEL_EOF = 96 MSG_CHANNEL_CLOSE = 97 MSG_CHANNEL_REQUEST = 98 MSG_CHANNEL_SUCCESS = 99 MSG_CHANNEL_FAILURE = 100 # SSH disconnect reason codes DISC_HOST_NOT_ALLOWED_TO_CONNECT = 1 DISC_PROTOCOL_ERROR = 2 DISC_KEY_EXCHANGE_FAILED = 3 DISC_RESERVED = 4 DISC_MAC_ERROR = 5 DISC_COMPRESSION_ERROR = 6 DISC_SERVICE_NOT_AVAILABLE = 7 DISC_PROTOCOL_VERSION_NOT_SUPPORTED = 8 DISC_HOST_KEY_NOT_VERIFYABLE = 9 DISC_CONNECTION_LOST = 10 DISC_BY_APPLICATION = 11 DISC_TOO_MANY_CONNECTIONS = 12 DISC_AUTH_CANCELLED_BY_USER = 13 DISC_NO_MORE_AUTH_METHODS_AVAILABLE = 14 DISC_ILLEGAL_USER_NAME = 15 # SSH channel open failure reason codes OPEN_ADMINISTRATIVELY_PROHIBITED = 1 OPEN_CONNECT_FAILED = 2 OPEN_UNKNOWN_CHANNEL_TYPE = 3 OPEN_RESOURCE_SHORTAGE = 4 # Internal failure reason codes OPEN_REQUEST_X11_FORWARDING_FAILED = 0xfffffffd OPEN_REQUEST_PTY_FAILED = 0xfffffffe OPEN_REQUEST_SESSION_FAILED = 0xffffffff # SSH file transfer packet types FXP_INIT = 1 FXP_VERSION = 2 FXP_OPEN = 3 FXP_CLOSE = 4 FXP_READ = 5 FXP_WRITE = 6 FXP_LSTAT = 7 FXP_FSTAT = 8 FXP_SETSTAT = 9 FXP_FSETSTAT = 10 FXP_OPENDIR = 11 FXP_READDIR = 12 FXP_REMOVE = 13 FXP_MKDIR = 14 FXP_RMDIR = 15 FXP_REALPATH = 16 FXP_STAT = 17 FXP_RENAME = 18 FXP_READLINK = 19 FXP_SYMLINK = 20 FXP_STATUS = 101 FXP_HANDLE = 102 FXP_DATA = 103 FXP_NAME = 104 FXP_ATTRS = 105 FXP_EXTENDED = 200 FXP_EXTENDED_REPLY = 201 # SSH file transfer open flags FXF_READ = 0x00000001 FXF_WRITE = 0x00000002 FXF_APPEND = 0x00000004 FXF_CREAT = 0x00000008 FXF_TRUNC = 0x00000010 FXF_EXCL = 0x00000020 # SSH file transfer attribute flags FILEXFER_ATTR_SIZE = 0x00000001 FILEXFER_ATTR_UIDGID = 0x00000002 FILEXFER_ATTR_PERMISSIONS = 0x00000004 FILEXFER_ATTR_ACMODTIME = 0x00000008 FILEXFER_ATTR_EXTENDED = 0x80000000 FILEXFER_ATTR_UNDEFINED = 0x7ffffff0 # OpenSSH statvfs attribute flags FXE_STATVFS_ST_RDONLY = 0x1 FXE_STATVFS_ST_NOSUID = 0x2 # SSH file transfer error codes FX_OK = 0 FX_EOF = 1 FX_NO_SUCH_FILE = 2 FX_PERMISSION_DENIED = 3 FX_FAILURE = 4 FX_BAD_MESSAGE = 5 FX_NO_CONNECTION = 6 FX_CONNECTION_LOST = 7 FX_OP_UNSUPPORTED = 8 # SSH channel data type codes EXTENDED_DATA_STDERR = 1 # SSH pty mode opcodes PTY_OP_END = 0 PTY_VINTR = 1 PTY_VQUIT = 2 PTY_VERASE = 3 PTY_VKILL = 4 PTY_VEOF = 5 PTY_VEOL = 6 PTY_VEOL2 = 7 PTY_VSTART = 8 PTY_VSTOP = 9 PTY_VSUSP = 10 PTY_VDSUSP = 11 PTY_VREPRINT = 12 PTY_WERASE = 13 PTY_VLNEXT = 14 PTY_VFLUSH = 15 PTY_VSWTCH = 16 PTY_VSTATUS = 17 PTY_VDISCARD = 18 PTY_IGNPAR = 30 PTY_PARMRK = 31 PTY_INPCK = 32 PTY_ISTRIP = 33 
PTY_INLCR = 34 PTY_IGNCR = 35 PTY_ICRNL = 36 PTY_IUCLC = 37 PTY_IXON = 38 PTY_IXANY = 39 PTY_IXOFF = 40 PTY_IMAXBEL = 41 PTY_IUTF8 = 42 PTY_ISIG = 50 PTY_ICANON = 51 PTY_XCASE = 52 PTY_ECHO = 53 PTY_ECHOE = 54 PTY_ECHOK = 55 PTY_ECHONL = 56 PTY_NOFLSH = 57 PTY_TOSTOP = 58 PTY_IEXTEN = 59 PTY_ECHOCTL = 60 PTY_ECHOKE = 61 PTY_PENDIN = 62 PTY_OPOST = 70 PTY_OLCUC = 71 PTY_ONLCR = 72 PTY_OCRNL = 73 PTY_ONOCR = 74 PTY_ONLRET = 75 PTY_CS7 = 90 PTY_CS8 = 91 PTY_PARENB = 92 PTY_PARODD = 93 PTY_OP_ISPEED = 128 PTY_OP_OSPEED = 129 PTY_OP_RESERVED = 160 asyncssh-1.11.1/asyncssh/crypto/000077500000000000000000000000001320320510200165655ustar00rootroot00000000000000asyncssh-1.11.1/asyncssh/crypto/__init__.py000066400000000000000000000027521320320510200207040ustar00rootroot00000000000000# Copyright (c) 2014-2017 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """A shim for accessing cryptographic primitives needed by asyncssh""" from .cipher import register_cipher, lookup_cipher from .ec import lookup_ec_curve_by_params # Import PyCA versions of DSA, ECDSA, RSA, and PBKDF2 from .pyca.dsa import DSAPrivateKey, DSAPublicKey from .pyca.ec import ECDSAPrivateKey, ECDSAPublicKey, ECDH from .pyca.rsa import RSAPrivateKey, RSAPublicKey from .pyca.kdf import pbkdf2_hmac # Import pyca module to get ciphers defined there registered from . import pyca # Import chacha20-poly1305 cipher if available from . import chacha # Import curve25519 DH if available try: from .curve25519 import Curve25519DH except ImportError: # pragma: no cover pass # Import umac cryptographic hash if available try: from .umac import umac32, umac64, umac96, umac128 except (ImportError, AttributeError, OSError): # pragma: no cover pass # Import X.509 certificate support if available try: from .pyca.x509 import X509Name, X509NamePattern from .pyca.x509 import generate_x509_certificate, import_x509_certificate except ImportError: # pragma: no cover pass asyncssh-1.11.1/asyncssh/crypto/chacha.py000066400000000000000000000104271320320510200203520ustar00rootroot00000000000000# Copyright (c) 2015 by Ron Frederick . # All rights reserved. 
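# --- Editor's illustrative sketch (not part of the original distribution) ---
# The crypto package __init__ above exposes a small cipher registry: modules
# call register_cipher() at import time and consumers fetch a factory back
# with lookup_cipher().  A minimal use of that registry might look like this;
# AES-256-CTR and the plaintext are example choices, not anything mandated by
# the library.

def _example_cipher_registry():
    """Hypothetical symmetric encrypt/decrypt via the cipher registry"""

    import os
    from asyncssh.crypto import lookup_cipher

    factory = lookup_cipher('aes', 'ctr')        # returns None if unregistered
    key, iv = os.urandom(32), os.urandom(factory.iv_size)

    ciphertext = factory.new(key, iv).encrypt(b'example plaintext')
    plaintext = factory.new(key, iv).decrypt(ciphertext)

    assert plaintext == b'example plaintext'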
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Chacha20-Poly1305 symmetric encryption handler""" import ctypes from .cipher import register_cipher class _Chacha20Poly1305Cipher: """Handler for Chacha20-Poly1305 symmetric encryption""" block_size = 1 iv_size = 0 def __init__(self, key): if len(key) != 2 * _CHACHA20_KEYBYTES: raise ValueError('Invalid chacha20-poly1305 key size') self._key = key[:_CHACHA20_KEYBYTES] self._adkey = key[_CHACHA20_KEYBYTES:] @classmethod def new(cls, key, iv=None, initial_bytes=0): """Construct a new chacha20-poly1305 cipher object""" # pylint: disable=unused-argument return cls(key) def _crypt(self, key, data, nonce, ctr=0): """Encrypt/decrypt a block of data""" # pylint: disable=no-self-use datalen = len(data) result = ctypes.create_string_buffer(datalen) datalen = ctypes.c_ulonglong(datalen) ctr = ctypes.c_ulonglong(ctr) if _chacha20_xor_ic(result, data, datalen, nonce, ctr, key) != 0: raise ValueError('Chacha encryption failed') # pragma: no cover return result.raw def _polykey(self, nonce): """Generate a poly1305 key""" polykey = ctypes.create_string_buffer(_POLY1305_KEYBYTES) polykeylen = ctypes.c_ulonglong(_POLY1305_KEYBYTES) if _chacha20(polykey, polykeylen, nonce, self._key) != 0: raise ValueError('Poly1305 key gen failed') # pragma: no cover return polykey def _compute_tag(self, data, nonce): """Compute a poly1305 tag for a block of data""" tag = ctypes.create_string_buffer(_POLY1305_BYTES) datalen = ctypes.c_ulonglong(len(data)) polykey = self._polykey(nonce) if _poly1305(tag, data, datalen, polykey) != 0: raise ValueError('Poly1305 tag gen failed') # pragma: no cover return tag.raw def _verify_tag(self, data, nonce, tag): """Verify a poly1305 tag on a block of data""" datalen = ctypes.c_ulonglong(len(data)) polykey = self._polykey(nonce) return _poly1305_verify(tag, data, datalen, polykey) == 0 def crypt_len(self, data, nonce): """Encrypt/decrypt an SSH packet length value""" if len(nonce) != _CHACHA20_NONCEBYTES: raise ValueError('Invalid chacha20-poly1305 nonce size') return self._crypt(self._adkey, data, nonce) def encrypt_and_sign(self, header, data, nonce): """Encrypt and sign a block of data""" if len(nonce) != _CHACHA20_NONCEBYTES: raise ValueError('Invalid chacha20-poly1305 nonce size') ciphertext = self._crypt(self._key, data, nonce, 1) tag = self._compute_tag(header + ciphertext, nonce) return ciphertext, tag def verify_and_decrypt(self, header, data, nonce, tag): """Verify the signature of and decrypt a block of data""" if len(nonce) != _CHACHA20_NONCEBYTES: raise ValueError('Invalid chacha20-poly1305 nonce size') if self._verify_tag(header + data, nonce, tag): plaintext = self._crypt(self._key, data, nonce, 1) else: plaintext = None return plaintext try: # pylint: disable=wrong-import-position,wrong-import-order from libnacl import nacl _CHACHA20_KEYBYTES = nacl.crypto_stream_chacha20_keybytes() _CHACHA20_NONCEBYTES = nacl.crypto_stream_chacha20_noncebytes() _chacha20 = nacl.crypto_stream_chacha20 _chacha20_xor_ic = nacl.crypto_stream_chacha20_xor_ic _POLY1305_BYTES = nacl.crypto_onetimeauth_poly1305_bytes() _POLY1305_KEYBYTES = nacl.crypto_onetimeauth_poly1305_keybytes() _poly1305 = nacl.crypto_onetimeauth_poly1305 _poly1305_verify = nacl.crypto_onetimeauth_poly1305_verify 
except (ImportError, OSError, AttributeError): # pragma: no cover pass else: register_cipher('chacha20-poly1305', 'chacha', _Chacha20Poly1305Cipher) asyncssh-1.11.1/asyncssh/crypto/cipher.py000066400000000000000000000020051320320510200204060ustar00rootroot00000000000000# Copyright (c) 2014-2015 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """A shim for accessing symmetric ciphers needed by asyncssh""" _ciphers = {} def register_cipher(cipher_name, mode_name, cipher): """Register a symmetric cipher If multiple modules try to register the same cipher and mode, the first one to register it is used. """ if (cipher_name, mode_name) not in _ciphers: # pragma: no branch cipher.cipher_name = cipher_name cipher.mode_name = mode_name _ciphers[cipher_name, mode_name] = cipher def lookup_cipher(cipher_name, mode_name): """Look up a symmetric cipher""" return _ciphers.get((cipher_name, mode_name)) asyncssh-1.11.1/asyncssh/crypto/curve25519.py000066400000000000000000000036741320320510200207030ustar00rootroot00000000000000# Copyright (c) 2015 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Curve25519 key exchange handler primitives""" import ctypes import os try: from libnacl import nacl _CURVE25519_BYTES = nacl.crypto_scalarmult_curve25519_bytes() _CURVE25519_SCALARBYTES = nacl.crypto_scalarmult_curve25519_scalarbytes() _curve25519 = nacl.crypto_scalarmult_curve25519 _curve25519_base = nacl.crypto_scalarmult_curve25519_base except (ImportError, OSError, AttributeError): # pragma: no cover pass else: class Curve25519DH: """Curve25519 Diffie Hellman implementation""" def __init__(self): self._private = os.urandom(_CURVE25519_SCALARBYTES) def get_public(self): """Return the public key to send in the handshake""" public = ctypes.create_string_buffer(_CURVE25519_BYTES) if _curve25519_base(public, self._private) != 0: # This error is never returned by libsodium raise ValueError('Curve25519 failed') # pragma: no cover return public.raw def get_shared(self, peer_public): """Return the shared key from the peer's public key""" if len(peer_public) != _CURVE25519_BYTES: raise AssertionError('Invalid curve25519 public key size') shared = ctypes.create_string_buffer(_CURVE25519_BYTES) if _curve25519(shared, self._private, peer_public) != 0: # This error is never returned by libsodium raise ValueError('Curve25519 failed') # pragma: no cover return int.from_bytes(shared.raw, 'big') asyncssh-1.11.1/asyncssh/crypto/ec.py000066400000000000000000000073241320320510200175340ustar00rootroot00000000000000# Copyright (c) 2013-2016 by Ron Frederick . # All rights reserved. 
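# --- Editor's illustrative sketch (not part of the original distribution) ---
# Curve25519DH above is only defined when libsodium/libnacl can be loaded, so
# this sketch guards the import.  It simply checks that two freshly generated
# parties derive the same shared secret from each other's public values.

def _example_curve25519_exchange():
    """Hypothetical Curve25519 Diffie-Hellman round trip"""

    try:
        from asyncssh.crypto import Curve25519DH
    except ImportError:                      # libsodium not available
        return None

    alice, bob = Curve25519DH(), Curve25519DH()
    shared = alice.get_shared(bob.get_public())
    assert shared == bob.get_shared(alice.get_public())
    return shared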
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Elliptic curve public key utility functions""" _curve_param_map = {} # Short variable names are used here, matching names in the spec # pylint: disable=invalid-name def register_prime_curve(curve_id, p, a, b, point, n): """Register an elliptic curve prime domain This function registers an elliptic curve prime domain by specifying the SSH identifier for the curve and the set of parameters describing the curve, generator point, and order. This allows EC keys encoded with explicit parameters to be mapped back into their SSH curve IDs. """ _curve_param_map[p, a % p, b % p, point, n] = curve_id def lookup_ec_curve_by_params(p, a, b, point, n): """Look up an elliptic curve by its parameters This function looks up an elliptic curve by its parameters and returns the curve's name. """ try: return _curve_param_map[p, a % p, b % p, point, n] except (KeyError, ValueError): raise ValueError('Unknown elliptic curve parameters') # pylint: disable=line-too-long register_prime_curve(b'nistp521', 6864797660130609714981900799081393217269435300143305409394463459185543183397656052122559640661454554977296311391480858037121987999716643812574028291115057151, -3, 1093849038073734274511112390766805569936207598951683748994586394495953116150735016013708737573759623248592132296706313309438452531591012912142327488478985984, b'\x04\x00\xc6\x85\x8e\x06\xb7\x04\x04\xe9\xcd\x9e>\xcbf#\x95\xb4B\x9cd\x819\x05?\xb5!\xf8(\xaf`kM=\xba\xa1K^w\xef\xe7Y(\xfe\x1d\xc1\'\xa2\xff\xa8\xde3H\xb3\xc1\x85jB\x9b\xf9~~1\xc2\xe5\xbdf\x01\x189)jx\x9a;\xc0\x04\\\x8a_\xb4,}\x1b\xd9\x98\xf5DIW\x9bDh\x17\xaf\xbd\x17\'>f,\x97\xeer\x99^\xf4&@\xc5P\xb9\x01?\xad\x07a5. # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """A shim around PyCA for accessing cryptographic primitives""" from . import cipher asyncssh-1.11.1/asyncssh/crypto/pyca/cipher.py000066400000000000000000000112311320320510200213430ustar00rootroot00000000000000# Copyright (c) 2014-2015 by Ron Frederick . # All rights reserved. 
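# --- Editor's illustrative sketch (not part of the original distribution) ---
# register_prime_curve() above keys its lookup table on the tuple
# (p, a % p, b % p, point, n), which is what lets EC keys encoded with explicit
# parameters be mapped back to an SSH curve name.  The toy, non-cryptographic
# parameters below are invented purely to show the round trip through that
# table; registering them adds a throwaway entry to the module-level map.

def _example_curve_param_lookup():
    """Hypothetical curve registration and lookup by parameters"""

    from asyncssh.crypto.ec import (register_prime_curve,
                                    lookup_ec_curve_by_params)

    p, a, b, point, n = 23, 1, 1, b'\x04\x03\x0a', 28
    register_prime_curve(b'toy23', p, a, b, point, n)

    assert lookup_ec_curve_by_params(p, a, b, point, n) == b'toy23'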
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """A shim around PyCA for symmetric encryption""" from cryptography.exceptions import InvalidTag from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.ciphers import Cipher from cryptography.hazmat.primitives.ciphers.algorithms import AES, ARC4 from cryptography.hazmat.primitives.ciphers.algorithms import Blowfish, CAST5 from cryptography.hazmat.primitives.ciphers.algorithms import TripleDES from cryptography.hazmat.primitives.ciphers.modes import CBC, CTR, GCM from ..cipher import register_cipher # pylint: disable=bad-whitespace _ciphers = {'aes': (AES, {'cbc': CBC, 'ctr': CTR, 'gcm': GCM}), 'arc4': (ARC4, {None: None}), 'blowfish': (Blowfish, {'cbc': CBC}), 'cast': (CAST5, {'cbc': CBC}), 'des': (TripleDES, {'cbc': CBC}), 'des3': (TripleDES, {'cbc': CBC})} # pylint: enable=bad-whitespace class GCMShim: """Shim for PyCA AES-GCM ciphers""" def __init__(self, cipher, block_size, key, iv): self._cipher = cipher self._key = key self._iv = iv self.block_size = block_size def _update_iv(self): """Update the IV after each encrypt/decrypt operation""" invocation = int.from_bytes(self._iv[4:], 'big') invocation = (invocation + 1) & 0xffffffffffffffff self._iv = self._iv[:4] + invocation.to_bytes(8, 'big') def encrypt_and_sign(self, header, data): """Encrypt and sign a block of data""" encryptor = Cipher(self._cipher(self._key), GCM(self._iv), default_backend()).encryptor() if header: encryptor.authenticate_additional_data(header) ciphertext = encryptor.update(data) + encryptor.finalize() self._update_iv() return ciphertext, encryptor.tag def verify_and_decrypt(self, header, data, tag): """Verify the signature of and decrypt a block of data""" decryptor = Cipher(self._cipher(self._key), GCM(self._iv, tag), default_backend()).decryptor() if header: decryptor.authenticate_additional_data(header) try: plaintext = decryptor.update(data) + decryptor.finalize() except InvalidTag: plaintext = None self._update_iv() return plaintext class CipherShim: """Shim for other PyCA ciphers""" def __init__(self, cipher, mode, block_size, key, iv, initial_bytes): if mode: mode = mode(iv) self._cipher = Cipher(cipher(key), mode, default_backend()) self._initial_bytes = initial_bytes self._encryptor = None self._decryptor = None self.block_size = block_size self.mode_name = None # set by register_cipher() def encrypt(self, data): """Encrypt a block of data""" if not self._encryptor: self._encryptor = self._cipher.encryptor() if self._initial_bytes: self._encryptor.update(self._initial_bytes * b'\0') return self._encryptor.update(data) def decrypt(self, data): """Decrypt a block of data""" if not self._decryptor: self._decryptor = self._cipher.decryptor() if self._initial_bytes: self._decryptor.update(self._initial_bytes * b'\0') return self._decryptor.update(data) class CipherFactory: """A factory which returns shims for PyCA symmetric encryption""" def __init__(self, cipher, mode): self._cipher = cipher self._mode = mode self.block_size = 1 if cipher == ARC4 else cipher.block_size // 8 self.iv_size = 12 if mode == GCM else self.block_size def new(self, key, iv=None, initial_bytes=0): """Construct a new symmetric cipher object""" if self._mode == GCM: return GCMShim(self._cipher, 
self.block_size, key, iv) else: return CipherShim(self._cipher, self._mode, self.block_size, key, iv, initial_bytes) for _cipher_name, (_cipher, _modes) in _ciphers.items(): for _mode_name, _mode in _modes.items(): register_cipher(_cipher_name, _mode_name, CipherFactory(_cipher, _mode)) asyncssh-1.11.1/asyncssh/crypto/pyca/dsa.py000066400000000000000000000060171320320510200206460ustar00rootroot00000000000000# Copyright (c) 2014-2017 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """A shim around PyCA for DSA public and private keys""" from cryptography.exceptions import InvalidSignature from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.hashes import SHA1 from cryptography.hazmat.primitives.asymmetric import dsa from .misc import PyCAKey # Short variable names are used here, matching names in the spec # pylint: disable=invalid-name class _DSAKey(PyCAKey): """Base class for shim around PyCA for DSA keys""" def __init__(self, pyca_key, params, pub, priv=None): super().__init__(pyca_key) self._params = params self._pub = pub self._priv = priv @property def p(self): """Return the DSA public modulus""" return self._params.p @property def q(self): """Return the DSA sub-group order""" return self._params.q @property def g(self): """Return the DSA generator""" return self._params.g @property def y(self): """Return the DSA public value""" return self._pub.y @property def x(self): """Return the DSA private value""" return self._priv.x if self._priv else None class DSAPrivateKey(_DSAKey): """A shim around PyCA for DSA private keys""" @classmethod def construct(cls, p, q, g, y, x): """Construct a DSA private key""" params = dsa.DSAParameterNumbers(p, q, g) pub = dsa.DSAPublicNumbers(y, params) priv = dsa.DSAPrivateNumbers(x, pub) priv_key = priv.private_key(default_backend()) return cls(priv_key, params, pub, priv) @classmethod def generate(cls, key_size): """Generate a new DSA private key""" priv_key = dsa.generate_private_key(key_size, default_backend()) priv = priv_key.private_numbers() pub = priv.public_numbers params = pub.parameter_numbers return cls(priv_key, params, pub, priv) def sign(self, data): """Sign a block of data""" return self.pyca_key.sign(data, SHA1()) class DSAPublicKey(_DSAKey): """A shim around PyCA for DSA public keys""" @classmethod def construct(cls, p, q, g, y): """Construct a DSA public key""" params = dsa.DSAParameterNumbers(p, q, g) pub = dsa.DSAPublicNumbers(y, params) pub_key = pub.public_key(default_backend()) return cls(pub_key, params, pub) def verify(self, data, sig): """Verify the signature on a block of data""" try: self.pyca_key.verify(sig, data, SHA1()) return True except InvalidSignature: return False asyncssh-1.11.1/asyncssh/crypto/pyca/ec.py000066400000000000000000000123011320320510200204570ustar00rootroot00000000000000# Copyright (c) 2015-2017 by Ron Frederick . # All rights reserved. 
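# --- Editor's illustrative sketch (not part of the original distribution) ---
# GCMShim above implements AEAD-style encrypt_and_sign()/verify_and_decrypt()
# and bumps its invocation-counter IV after every operation, so an encryptor
# and a decryptor that start from the same key and IV stay in step packet by
# packet.  The key, IV and packet bytes below are made up for this example.

def _example_aes_gcm_roundtrip():
    """Hypothetical AES-256-GCM round trip through the cipher registry"""

    import os
    from asyncssh.crypto import lookup_cipher

    factory = lookup_cipher('aes', 'gcm')
    key, iv = os.urandom(32), os.urandom(factory.iv_size)   # iv_size == 12

    encryptor = factory.new(key, iv)
    decryptor = factory.new(key, iv)

    header, payload = b'\x00\x00\x00\x14', b'example packet data'
    ciphertext, tag = encryptor.encrypt_and_sign(header, payload)

    assert decryptor.verify_and_decrypt(header, ciphertext, tag) == payload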
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """A shim around PyCA for elliptic curve keys and key exchange""" from cryptography.exceptions import InvalidSignature from cryptography.hazmat.backends.openssl import backend from cryptography.hazmat.primitives.hashes import SHA256, SHA384, SHA512 from cryptography.hazmat.primitives.asymmetric import ec from .misc import PyCAKey # Short variable names are used here, matching names in the spec # pylint: disable=invalid-name _curves = {b'nistp256': (ec.SECP256R1, SHA256), b'nistp384': (ec.SECP384R1, SHA384), b'nistp521': (ec.SECP521R1, SHA512)} class _ECKey(PyCAKey): """Base class for shim around PyCA for EC keys""" def __init__(self, pyca_key, curve_id, hash_alg, pub, priv=None): super().__init__(pyca_key) self._curve_id = curve_id self._hash_alg = hash_alg self._pub = pub self._priv = priv @classmethod def lookup_curve(cls, curve_id): """Look up curve and hash algorithm""" try: return _curves[curve_id] except KeyError: # pragma: no cover, other curves not registered raise ValueError('Unknown EC curve %s' % curve_id.decode()) from None @property def curve_id(self): """Return the EC curve name""" return self._curve_id @property def x(self): """Return the EC public x coordinate""" return self._pub.x @property def y(self): """Return the EC public y coordinate""" return self._pub.y @property def d(self): """Return the EC private value as an integer""" return self._priv.private_value if self._priv else None @property def public_value(self): """Return the EC public point value encoded as a byte string""" return self._pub.encode_point() @property def private_value(self): """Return the EC private value encoded as a byte string""" if self._priv: keylen = (self._pub.curve.key_size + 7) // 8 return self._priv.private_value.to_bytes(keylen, 'big') else: return None class ECDSAPrivateKey(_ECKey): """A shim around PyCA for ECDSA private keys""" @classmethod def construct(cls, curve_id, public_value, private_value): """Construct an ECDSA private key""" curve, hash_alg = cls.lookup_curve(curve_id) pub = ec.EllipticCurvePublicNumbers.from_encoded_point(curve(), public_value) priv = ec.EllipticCurvePrivateNumbers(private_value, pub) priv_key = priv.private_key(backend) return cls(priv_key, curve_id, hash_alg, pub, priv) @classmethod def generate(cls, curve_id): """Generate a new ECDSA private key""" curve, hash_alg = cls.lookup_curve(curve_id) priv_key = ec.generate_private_key(curve, backend) priv = priv_key.private_numbers() pub = priv.public_numbers return cls(priv_key, curve_id, hash_alg, pub, priv) def sign(self, data): """Sign a block of data""" return self.pyca_key.sign(data, ec.ECDSA(self._hash_alg())) class ECDSAPublicKey(_ECKey): """A shim around PyCA for ECDSA public keys""" @classmethod def construct(cls, curve_id, public_value): """Construct an ECDSA public key""" curve, hash_alg = cls.lookup_curve(curve_id) pub = ec.EllipticCurvePublicNumbers.from_encoded_point(curve(), public_value) pub_key = pub.public_key(backend) return cls(pub_key, curve_id, hash_alg, pub) def verify(self, data, sig): """Verify the signature on a block of data""" try: self.pyca_key.verify(sig, data, ec.ECDSA(self._hash_alg())) return True except InvalidSignature: return False class ECDH: """A shim around PyCA for ECDH 
key exchange""" def __init__(self, curve_id): try: curve, _ = _curves[curve_id] except KeyError: # pragma: no cover, other curves not registered raise ValueError('Unknown EC curve %s' % curve_id.decode()) from None self._priv_key = ec.generate_private_key(curve, backend) def get_public(self): """Return the public key to send in the handshake""" pub = self._priv_key.private_numbers().public_numbers return pub.encode_point() def get_shared(self, peer_public): """Return the shared key from the peer's public key""" peer_key = ec.EllipticCurvePublicNumbers.from_encoded_point( self._priv_key.curve, peer_public).public_key(backend) shared_key = self._priv_key.exchange(ec.ECDH(), peer_key) return int.from_bytes(shared_key, 'big') asyncssh-1.11.1/asyncssh/crypto/pyca/kdf.py000066400000000000000000000020161320320510200206360ustar00rootroot00000000000000# Copyright (c) 2017 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """A shim around PyCA for key derivation functions""" from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.hashes import SHA1, SHA224, SHA256 from cryptography.hazmat.primitives.hashes import SHA384, SHA512 from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC _hashes = {h.name: h for h in (SHA1, SHA224, SHA256, SHA384, SHA512)} def pbkdf2_hmac(hash_name, passphrase, salt, count, key_size): """A shim around PyCA for PBKDF2 HMAC key derivation""" return PBKDF2HMAC(_hashes[hash_name](), key_size, salt, count, default_backend()).derive(passphrase) asyncssh-1.11.1/asyncssh/crypto/pyca/misc.py000066400000000000000000000013221320320510200210240ustar00rootroot00000000000000# Copyright (c) 2017 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Miscellaneous PyCA utility classes and functions""" class PyCAKey: """Base class for PyCA private/public keys""" def __init__(self, pyca_key): self._pyca_key = pyca_key @property def pyca_key(self): """Return the PyCA object associated with this key""" return self._pyca_key asyncssh-1.11.1/asyncssh/crypto/pyca/rsa.py000066400000000000000000000075411320320510200206670ustar00rootroot00000000000000# Copyright (c) 2014-2017 by Ron Frederick . # All rights reserved. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """A shim around PyCA for RSA public and private keys""" from cryptography.exceptions import InvalidSignature from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15 from cryptography.hazmat.primitives.hashes import SHA1, SHA256, SHA512 from cryptography.hazmat.primitives.asymmetric import rsa from .misc import PyCAKey # Short variable names are used here, matching names in the spec # pylint: disable=invalid-name class _RSAKey(PyCAKey): """Base class for shim around PyCA for RSA keys""" def __init__(self, pyca_key, pub, priv=None): super().__init__(pyca_key) self._pub = pub self._priv = priv @staticmethod def get_hash(algorithm): """Return hash algorithm to use for signature""" if algorithm == b'rsa-sha2-512': return SHA512() elif algorithm in (b'rsa-sha2-256', b'rsa2048-sha256'): return SHA256() else: return SHA1() @property def n(self): """Return the RSA public modulus""" return self._pub.n @property def e(self): """Return the RSA public exponent""" return self._pub.e @property def d(self): """Return the RSA private exponent""" return self._priv.d if self._priv else None @property def p(self): """Return the RSA first private prime""" return self._priv.p if self._priv else None @property def q(self): """Return the RSA second private prime""" return self._priv.q if self._priv else None @property def dmp1(self): """Return d modulo p-1""" return self._priv.dmp1 if self._priv else None @property def dmq1(self): """Return d modulo q-1""" return self._priv.dmq1 if self._priv else None @property def iqmp(self): """Return the inverse of q modulo p""" return self._priv.iqmp if self._priv else None class RSAPrivateKey(_RSAKey): """A shim around PyCA for RSA private keys""" @classmethod def construct(cls, n, e, d, p, q, dmp1, dmq1, iqmp): """Construct an RSA private key""" pub = rsa.RSAPublicNumbers(e, n) priv = rsa.RSAPrivateNumbers(p, q, d, dmp1, dmq1, iqmp, pub) priv_key = priv.private_key(default_backend()) return cls(priv_key, pub, priv) @classmethod def generate(cls, key_size, exponent): """Generate a new RSA private key""" priv_key = rsa.generate_private_key(exponent, key_size, default_backend()) priv = priv_key.private_numbers() pub = priv.public_numbers return cls(priv_key, pub, priv) def sign(self, data, algorithm): """Sign a block of data""" return self.pyca_key.sign(data, PKCS1v15(), self.get_hash(algorithm)) class RSAPublicKey(_RSAKey): """A shim around PyCA for RSA public keys""" @classmethod def construct(cls, n, e): """Construct an RSA public key""" pub = rsa.RSAPublicNumbers(e, n) pub_key = pub.public_key(default_backend()) return cls(pub_key, pub) def verify(self, data, sig, algorithm): """Verify the signature on a block of data""" try: self.pyca_key.verify(sig, data, PKCS1v15(), self.get_hash(algorithm)) return True except InvalidSignature: return False asyncssh-1.11.1/asyncssh/crypto/pyca/x509.py000066400000000000000000000277531320320510200206150ustar00rootroot00000000000000# Copyright (c) 2017 by Ron Frederick . # All rights reserved.
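# --- Editor's illustrative sketch (not part of the original distribution) ---
# The RSA shim above selects the PKCS#1 v1.5 hash from the SSH signature
# algorithm name passed in, so sign() and verify() must be given the same
# algorithm string.  The 2048-bit key size and message are example values.

def _example_rsa_sign_verify():
    """Hypothetical RSA signing round trip using rsa-sha2-256"""

    from asyncssh.crypto import RSAPrivateKey, RSAPublicKey

    priv = RSAPrivateKey.generate(2048, 65537)
    pub = RSAPublicKey.construct(priv.n, priv.e)

    sig = priv.sign(b'example data', b'rsa-sha2-256')

    assert pub.verify(b'example data', sig, b'rsa-sha2-256')
    assert not pub.verify(b'tampered data', sig, b'rsa-sha2-256')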
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """A shim around PyCA and PyOpenSSL for X.509 certificates""" from datetime import datetime, timezone from ipaddress import ip_address import re import sys from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.hashes import MD5, SHA1, SHA224 from cryptography.hazmat.primitives.hashes import SHA256, SHA384, SHA512 from cryptography.hazmat.primitives.serialization import Encoding from cryptography.hazmat.primitives.serialization import PublicFormat from cryptography import x509 from OpenSSL import crypto from ...asn1 import IA5String, der_decode, der_encode # pylint: disable=bad-whitespace _purpose_to_oid = { 'serverAuth': x509.ExtendedKeyUsageOID.SERVER_AUTH, 'clientAuth': x509.ExtendedKeyUsageOID.CLIENT_AUTH, 'secureShellClient': x509.ObjectIdentifier('1.3.6.1.5.5.7.3.21'), 'secureShellServer': x509.ObjectIdentifier('1.3.6.1.5.5.7.3.22')} # pylint: enable=bad-whitespace _purpose_any = x509.ObjectIdentifier('2.5.29.37.0') _hashes = {h.name: h for h in (MD5, SHA1, SHA224, SHA256, SHA384, SHA512)} _nscomment_oid = x509.ObjectIdentifier('2.16.840.1.113730.1.13') if sys.platform == 'win32': # pragma: no cover # Windows' datetime.max is year 9999, but timestamps that large don't work _gen_time_max = datetime(2999, 12, 31, 23, 59, 59, 999999, tzinfo=timezone.utc).timestamp() - 1 else: _gen_time_max = datetime.max.replace(tzinfo=timezone.utc).timestamp() - 1 def _to_generalized_time(t): """Convert a timestamp value to a datetime""" return datetime.utcfromtimestamp(max(1, min(t, _gen_time_max))) def _to_purpose_oids(purposes): """Convert a list of purposes to purpose OIDs""" if isinstance(purposes, str): purposes = [p.strip() for p in purposes.split(',')] if not purposes or 'any' in purposes or _purpose_any in purposes: purposes = None else: purposes = set(_purpose_to_oid.get(p) or x509.ObjectIdentifier(p) for p in purposes) return purposes def _encode_user_principals(principals): """Encode user principals as e-mail addresses""" if isinstance(principals, str): principals = [p.strip() for p in principals.split(',')] return [x509.RFC822Name(name) for name in principals] def _encode_host_principals(principals): """Encode host principals as DNS names or IP addresses""" def _encode_host(name): """Encode a host principal as a DNS name or IP address""" try: return x509.IPAddress(ip_address(name)) except ValueError: return x509.DNSName(name) if isinstance(principals, str): principals = [p.strip() for p in principals.split(',')] return [_encode_host(name) for name in principals] class X509Name(x509.Name): """A shim around PyCA for X.509 distinguished names""" _escape = re.compile(r'([,+\\])') _unescape = re.compile(r'\\([,+\\])') _split_rdn = re.compile(r'(?:[^+\\]+|\\.)+') _split_name = re.compile(r'(?:[^,\\]+|\\.)+') # pylint: disable=bad-whitespace _attrs = ( ('C', x509.NameOID.COUNTRY_NAME), ('ST', x509.NameOID.STATE_OR_PROVINCE_NAME), ('L', x509.NameOID.LOCALITY_NAME), ('O', x509.NameOID.ORGANIZATION_NAME), ('OU', x509.NameOID.ORGANIZATIONAL_UNIT_NAME), ('CN', x509.NameOID.COMMON_NAME), ('DC', x509.NameOID.DOMAIN_COMPONENT)) # pylint: enable=bad-whitespace _to_oid = dict((k, v) for k, v in _attrs) _from_oid = dict((v, k) for k, v in _attrs) def __init__(self, 
name): if isinstance(name, str): name = self._parse_name(name) elif isinstance(name, x509.Name): name = name.rdns super().__init__(name) def __str__(self): return ','.join(self._format_rdn(rdn) for rdn in self.rdns) def _format_rdn(self, rdn): """Format an X.509 RelativeDistinguishedName as a string""" return '+'.join(sorted(self._format_attr(nameattr) for nameattr in rdn)) def _format_attr(self, nameattr): """Format an X.509 NameAttribute as a string""" attr = self._from_oid.get(nameattr.oid) or nameattr.oid.dotted_string return attr + '=' + self._escape.sub(r'\\\1', nameattr.value) def _parse_name(self, name): """Parse an X.509 distinguished name""" return (self._parse_rdn(rdn) for rdn in self._split_name.findall(name)) def _parse_rdn(self, rdn): """Parse an X.509 relative distinguished name""" return x509.RelativeDistinguishedName( self._parse_nameattr(av) for av in self._split_rdn.findall(rdn)) def _parse_nameattr(self, av): """Parse an X.509 name attribute/value pair""" try: attr, value = av.split('=', 1) except ValueError: raise ValueError('Invalid X.509 name attribute: ' + av) from None try: attr = attr.strip() oid = self._to_oid.get(attr) or x509.ObjectIdentifier(attr) except ValueError: raise ValueError('Unknown X.509 attribute: ' + attr) from None return x509.NameAttribute(oid, self._unescape.sub(r'\1', value)) class X509NamePattern: """Match X.509 distinguished names""" def __init__(self, pattern): if pattern.endswith(',*'): self._pattern = X509Name(pattern[:-2]) self._prefix_len = len(self._pattern.rdns) else: self._pattern = X509Name(pattern) self._prefix_len = None def __eq__(self, other): # This isn't protected access - both objects are _RSAKey instances # pylint: disable=protected-access return (isinstance(other, type(self)) and self._pattern == other._pattern and self._prefix_len == other._prefix_len) def __hash__(self): return hash((self._pattern, self._prefix_len)) def matches(self, name): """Return whether an X.509 name matches this pattern""" return self._pattern.rdns == name.rdns[:self._prefix_len] class X509Certificate: """A shim around PyCA and PyOpenSSL for X.509 certificates""" def __init__(self, cert, data): self.data = data self.subject = X509Name(cert.subject) self.issuer = X509Name(cert.issuer) self.key_data = cert.public_key().public_bytes( Encoding.DER, PublicFormat.SubjectPublicKeyInfo) self.openssl_cert = crypto.X509.from_cryptography(cert) self.subject_hash = hex(self.openssl_cert.get_subject().hash())[2:] self.issuer_hash = hex(self.openssl_cert.get_issuer().hash())[2:] try: self.purposes = set(cert.extensions.get_extension_for_class( x509.ExtendedKeyUsage).value) except x509.ExtensionNotFound: self.purposes = None try: sans = cert.extensions.get_extension_for_class( x509.SubjectAlternativeName).value self.user_principals = sans.get_values_for_type(x509.RFC822Name) self.host_principals = sans.get_values_for_type(x509.DNSName) + \ [str(ip) for ip in sans.get_values_for_type(x509.IPAddress)] except x509.ExtensionNotFound: cn = cert.subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME) principals = [attr.value for attr in cn] self.user_principals = principals self.host_principals = principals try: comment = cert.extensions.get_extension_for_oid(_nscomment_oid) self.comment = str(der_decode(comment.value.value)) except x509.ExtensionNotFound: self.comment = None except UnicodeDecodeError: raise ValueError('Invalid character in comment') from None def __eq__(self, other): return isinstance(other, type(self)) and self.data == other.data def __hash__(self): 
return hash(self.data) def validate(self, trust_store, purposes, user_principal, host_principal): """Validate an X.509 certificate""" purposes = _to_purpose_oids(purposes) if purposes and self.purposes and not purposes & self.purposes: raise ValueError('Certificate purpose mismatch') if user_principal and user_principal not in self.user_principals: raise ValueError('Certificate user principal mismatch') if host_principal and host_principal not in self.host_principals: raise ValueError('Certificate host principal mismatch') x509_store = crypto.X509Store() for c in trust_store: x509_store.add_cert(c.openssl_cert) try: x509_ctx = crypto.X509StoreContext(x509_store, self.openssl_cert) x509_ctx.verify_certificate() except crypto.X509StoreContextError as exc: raise ValueError(str(exc)) from None def generate_x509_certificate(signing_key, key, subject, issuer, serial, valid_after, valid_before, ca, ca_path_len, purposes, user_principals, host_principals, hash_alg, comment): """Generate a new X.509 certificate""" subject = X509Name(subject) issuer = X509Name(issuer) if issuer else subject valid_after = _to_generalized_time(valid_after) valid_before = _to_generalized_time(valid_before) purposes = _to_purpose_oids(purposes) self_signed = subject == issuer if serial is None: serial = x509.random_serial_number() builder = x509.CertificateBuilder() builder = builder.subject_name(subject) builder = builder.issuer_name(issuer) builder = builder.serial_number(serial) builder = builder.not_valid_before(valid_after) builder = builder.not_valid_after(valid_before) builder = builder.public_key(key.pyca_key) if ca: basic_constraints = x509.BasicConstraints( ca=True, path_length=ca_path_len) key_usage = x509.KeyUsage( digital_signature=True, content_commitment=False, key_encipherment=False, data_encipherment=False, key_agreement=False, key_cert_sign=True, crl_sign=True, encipher_only=False, decipher_only=False) else: basic_constraints = x509.BasicConstraints(ca=False, path_length=None) key_usage = x509.KeyUsage( digital_signature=True, content_commitment=False, key_encipherment=True, data_encipherment=False, key_agreement=False, key_cert_sign=self_signed, crl_sign=False, encipher_only=False, decipher_only=False) builder = builder.add_extension(basic_constraints, critical=True) builder = builder.add_extension(key_usage, critical=True) if purposes: builder = builder.add_extension( x509.ExtendedKeyUsage(purposes), critical=False) sans = (_encode_user_principals(user_principals) + _encode_host_principals(host_principals)) if sans: builder = builder.add_extension( x509.SubjectAlternativeName(sans), critical=False) if comment: comment = der_encode(IA5String(comment)) builder = builder.add_extension( x509.UnrecognizedExtension(_nscomment_oid, comment), critical=False) try: hash_alg = _hashes[hash_alg]() except KeyError: raise ValueError('Unknown hash algorithm') from None cert = builder.sign(signing_key.pyca_key, hash_alg, default_backend()) data = cert.public_bytes(Encoding.DER) return X509Certificate(cert, data) def import_x509_certificate(data): """Construct an X.509 certificate from DER data""" cert = x509.load_der_x509_certificate(data, default_backend()) return X509Certificate(cert, data) asyncssh-1.11.1/asyncssh/crypto/umac.py000066400000000000000000000063161320320510200200720ustar00rootroot00000000000000# Copyright (c) 2016 by Ron Frederick . # All rights reserved. 
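# --- Editor's illustrative sketch (not part of the original distribution) ---
# generate_x509_certificate() above is only exported when PyOpenSSL is
# installed, so this sketch guards the import.  It builds a self-signed,
# non-CA host certificate; the subject, host name, lifetime and comment are
# invented example values, and the same key object is passed as both the
# signing key and the certified key.

def _example_self_signed_x509():
    """Hypothetical self-signed X.509 host certificate"""

    import time

    try:
        from asyncssh.crypto import RSAPrivateKey, generate_x509_certificate
    except ImportError:                      # PyOpenSSL not available
        return None

    key = RSAPrivateKey.generate(2048, 65537)
    now = int(time.time())

    cert = generate_x509_certificate(key, key, 'CN=example host', None, None,
                                     now, now + 365 * 24 * 60 * 60,
                                     False, None, 'secureShellServer',
                                     [], 'example.com', 'sha256',
                                     'editor example comment')

    assert 'example.com' in cert.host_principals
    return cert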
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """UMAC cryptographic hash (RFC 4418) wrapper for Nettle library""" import binascii import ctypes import ctypes.util import sys _UMAC_BLOCK_SIZE = 1024 _UMAC_DEFAULT_CTX_SIZE = 4096 def __build_umac(size): """Function to build UMAC wrapper for a specific digest size""" _name = 'umac%d' % size _prefix = 'nettle_%s_' % _name try: _context_size = getattr(_nettle, _prefix + '_ctx_size')() except AttributeError: _context_size = _UMAC_DEFAULT_CTX_SIZE _set_key = getattr(_nettle, _prefix + 'set_key') _set_nonce = getattr(_nettle, _prefix + 'set_nonce') _update = getattr(_nettle, _prefix + 'update') _digest = getattr(_nettle, _prefix + 'digest') class _UMAC: """Wrapper for UMAC cryptographic hash This class supports the cryptographic hash API defined in PEP 452. """ name = _name block_size = _UMAC_BLOCK_SIZE digest_size = size // 8 def __init__(self, ctx, nonce=None, msg=None): self._ctx = ctx if nonce: self.set_nonce(nonce) if msg: self.update(msg) @classmethod def new(cls, key, msg=None, nonce=None): """Construct a new UMAC hash object""" ctx = ctypes.create_string_buffer(_context_size) _set_key(ctx, key) return cls(ctx, nonce, msg) def copy(self): """Return a new hash object with this object's state""" ctx = ctypes.create_string_buffer(self._ctx.raw) return self.__class__(ctx) def set_nonce(self, nonce): """Reset the nonce associated with this object""" _set_nonce(self._ctx, ctypes.c_size_t(len(nonce)), nonce) def update(self, msg): """Add the data in msg to the hash""" _update(self._ctx, ctypes.c_size_t(len(msg)), msg) def digest(self): """Return the hash and increment nonce to begin a new message .. note:: The hash is reset and the nonce is incremented when this function is called. This doesn't match the behavior defined in PEP 452. """ result = ctypes.create_string_buffer(self.digest_size) _digest(self._ctx, ctypes.c_size_t(self.digest_size), result) return result.raw def hexdigest(self): """Return the digest as a string of hexadecimal digits""" return binascii.b2a_hex(self.digest()).decode('ascii') globals()[_name] = _UMAC.new digest_size = None if sys.platform == 'win32': # pragma: no cover _nettle = ctypes.cdll.LoadLibrary('libnettle-6') else: _nettle = ctypes.cdll.LoadLibrary(ctypes.util.find_library('nettle')) for _size in (32, 64, 96, 128): __build_umac(_size) asyncssh-1.11.1/asyncssh/dh.py000066400000000000000000000565241320320510200162260ustar00rootroot00000000000000# Copyright (c) 2013-2017 by Ron Frederick . # All rights reserved. 
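# --- Editor's illustrative sketch (not part of the original distribution) ---
# The umac32/64/96/128 constructors above are only defined when the Nettle
# library can be loaded, and digest() deliberately resets the state and bumps
# the nonce so the next message gets a fresh tag (as the docstring above
# notes, this differs from PEP 452).  The 16-byte key and 8-byte nonce below
# are example values.

def _example_umac64_tag():
    """Hypothetical UMAC-64 tag computation"""

    try:
        from asyncssh.crypto import umac64
    except ImportError:                      # Nettle not available
        return None

    mac = umac64(key=16 * b'\x01', msg=b'example message', nonce=8 * b'\x02')
    tag = mac.digest()                       # 8-byte tag; nonce auto-increments

    assert len(tag) == 8
    return tag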
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH Diffie-Hellman key exchange handlers""" from hashlib import sha1, sha256, sha512 from .constants import DEFAULT_LANG from .constants import DISC_KEY_EXCHANGE_FAILED, DISC_PROTOCOL_ERROR from .gss import GSSError from .kex import Kex, register_kex_alg, register_gss_kex_alg from .logging import logger from .misc import DisconnectError, randrange from .packet import Boolean, Byte, MPInt, String, UInt32 # pylint: disable=bad-whitespace,line-too-long # SSH KEX DH message values MSG_KEXDH_INIT = 30 MSG_KEXDH_REPLY = 31 # SSH KEX DH group exchange message values MSG_KEX_DH_GEX_REQUEST_OLD = 30 MSG_KEX_DH_GEX_GROUP = 31 MSG_KEX_DH_GEX_INIT = 32 MSG_KEX_DH_GEX_REPLY = 33 MSG_KEX_DH_GEX_REQUEST = 34 # SSH KEXGSS message value MSG_KEXGSS_INIT = 30 MSG_KEXGSS_CONTINUE = 31 MSG_KEXGSS_COMPLETE = 32 MSG_KEXGSS_HOSTKEY = 33 MSG_KEXGSS_ERROR = 34 MSG_KEXGSS_GROUPREQ = 40 MSG_KEXGSS_GROUP = 41 # SSH KEX group exchange key sizes KEX_DH_GEX_MIN_SIZE = 1024 KEX_DH_GEX_PREFERRED_SIZE = 2048 KEX_DH_GEX_MAX_SIZE = 8192 # SSH Diffie-Hellman group 1 parameters _group1_g = 2 _group1_p = 0xffffffffffffffffc90fdaa22168c234c4c6628b80dc1cd129024e088a67cc74020bbea63b139b22514a08798e3404ddef9519b3cd3a431b302b0a6df25f14374fe1356d6d51c245e485b576625e7ec6f44c42e9a637ed6b0bff5cb6f406b7edee386bfb5a899fa5ae9f24117c4b1fe649286651ece65381ffffffffffffffff # SSH Diffie-Hellman group 14 parameters _group14_g = 2 _group14_p = 0xffffffffffffffffc90fdaa22168c234c4c6628b80dc1cd129024e088a67cc74020bbea63b139b22514a08798e3404ddef9519b3cd3a431b302b0a6df25f14374fe1356d6d51c245e485b576625e7ec6f44c42e9a637ed6b0bff5cb6f406b7edee386bfb5a899fa5ae9f24117c4b1fe649286651ece45b3dc2007cb8a163bf0598da48361c55d39a69163fa8fd24cf5f83655d23dca3ad961c62f356208552bb9ed529077096966d670c354e4abc9804f1746c08ca18217c32905e462e36ce3be39e772c180e86039b2783a2ec07a28fb5c55df06f4c52c9de2bcbf6955817183995497cea956ae515d2261898fa051015728e5a8aacaa68ffffffffffffffff # SSH Diffie-Hellman group 15 parameters _group15_g = 2 _group15_p = 0xffffffffffffffffc90fdaa22168c234c4c6628b80dc1cd129024e088a67cc74020bbea63b139b22514a08798e3404ddef9519b3cd3a431b302b0a6df25f14374fe1356d6d51c245e485b576625e7ec6f44c42e9a637ed6b0bff5cb6f406b7edee386bfb5a899fa5ae9f24117c4b1fe649286651ece45b3dc2007cb8a163bf0598da48361c55d39a69163fa8fd24cf5f83655d23dca3ad961c62f356208552bb9ed529077096966d670c354e4abc9804f1746c08ca18217c32905e462e36ce3be39e772c180e86039b2783a2ec07a28fb5c55df06f4c52c9de2bcbf6955817183995497cea956ae515d2261898fa051015728e5a8aaac42dad33170d04507a33a85521abdf1cba64ecfb850458dbef0a8aea71575d060c7db3970f85a6e1e4c7abf5ae8cdb0933d71e8c94e04a25619dcee3d2261ad2ee6bf12ffa06d98a0864d87602733ec86a64521f2b18177b200cbbe117577a615d6c770988c0bad946e208e24fa074e5ab3143db5bfce0fd108e4b82d120a93ad2caffffffffffffffff # SSH Diffie-Hellman group 16 parameters _group16_g = 2 _group16_p = 
0xffffffffffffffffc90fdaa22168c234c4c6628b80dc1cd129024e088a67cc74020bbea63b139b22514a08798e3404ddef9519b3cd3a431b302b0a6df25f14374fe1356d6d51c245e485b576625e7ec6f44c42e9a637ed6b0bff5cb6f406b7edee386bfb5a899fa5ae9f24117c4b1fe649286651ece45b3dc2007cb8a163bf0598da48361c55d39a69163fa8fd24cf5f83655d23dca3ad961c62f356208552bb9ed529077096966d670c354e4abc9804f1746c08ca18217c32905e462e36ce3be39e772c180e86039b2783a2ec07a28fb5c55df06f4c52c9de2bcbf6955817183995497cea956ae515d2261898fa051015728e5a8aaac42dad33170d04507a33a85521abdf1cba64ecfb850458dbef0a8aea71575d060c7db3970f85a6e1e4c7abf5ae8cdb0933d71e8c94e04a25619dcee3d2261ad2ee6bf12ffa06d98a0864d87602733ec86a64521f2b18177b200cbbe117577a615d6c770988c0bad946e208e24fa074e5ab3143db5bfce0fd108e4b82d120a92108011a723c12a787e6d788719a10bdba5b2699c327186af4e23c1a946834b6150bda2583e9ca2ad44ce8dbbbc2db04de8ef92e8efc141fbecaa6287c59474e6bc05d99b2964fa090c3a2233ba186515be7ed1f612970cee2d7afb81bdd762170481cd0069127d5b05aa993b4ea988d8fddc186ffb7dc90a6c08f4df435c934063199ffffffffffffffff # SSH Diffie-Hellman group 17 parameters _group17_g = 2 _group17_p = 0xffffffffffffffffc90fdaa22168c234c4c6628b80dc1cd129024e088a67cc74020bbea63b139b22514a08798e3404ddef9519b3cd3a431b302b0a6df25f14374fe1356d6d51c245e485b576625e7ec6f44c42e9a637ed6b0bff5cb6f406b7edee386bfb5a899fa5ae9f24117c4b1fe649286651ece45b3dc2007cb8a163bf0598da48361c55d39a69163fa8fd24cf5f83655d23dca3ad961c62f356208552bb9ed529077096966d670c354e4abc9804f1746c08ca18217c32905e462e36ce3be39e772c180e86039b2783a2ec07a28fb5c55df06f4c52c9de2bcbf6955817183995497cea956ae515d2261898fa051015728e5a8aaac42dad33170d04507a33a85521abdf1cba64ecfb850458dbef0a8aea71575d060c7db3970f85a6e1e4c7abf5ae8cdb0933d71e8c94e04a25619dcee3d2261ad2ee6bf12ffa06d98a0864d87602733ec86a64521f2b18177b200cbbe117577a615d6c770988c0bad946e208e24fa074e5ab3143db5bfce0fd108e4b82d120a92108011a723c12a787e6d788719a10bdba5b2699c327186af4e23c1a946834b6150bda2583e9ca2ad44ce8dbbbc2db04de8ef92e8efc141fbecaa6287c59474e6bc05d99b2964fa090c3a2233ba186515be7ed1f612970cee2d7afb81bdd762170481cd0069127d5b05aa993b4ea988d8fddc186ffb7dc90a6c08f4df435c93402849236c3fab4d27c7026c1d4dcb2602646dec9751e763dba37bdf8ff9406ad9e530ee5db382f413001aeb06a53ed9027d831179727b0865a8918da3edbebcf9b14ed44ce6cbaced4bb1bdb7f1447e6cc254b332051512bd7af426fb8f401378cd2bf5983ca01c64b92ecf032ea15d1721d03f482d7ce6e74fef6d55e702f46980c82b5a84031900b1c9e59e7c97fbec7e8f323a97a7e36cc88be0f1d45b7ff585ac54bd407b22b4154aacc8f6d7ebf48e1d814cc5ed20f8037e0a79715eef29be32806a1d58bb7c5da76f550aa3d8a1fbff0eb19ccb1a313d55cda56c9ec2ef29632387fe8d76e3c0468043e8f663f4860ee12bf2d5b0b7474d6e694f91e6dcc4024ffffffffffffffff # SSH Diffie-Hellman group 18 parameters _group18_g = 2 _group18_p = 
0xffffffffffffffffc90fdaa22168c234c4c6628b80dc1cd129024e088a67cc74020bbea63b139b22514a08798e3404ddef9519b3cd3a431b302b0a6df25f14374fe1356d6d51c245e485b576625e7ec6f44c42e9a637ed6b0bff5cb6f406b7edee386bfb5a899fa5ae9f24117c4b1fe649286651ece45b3dc2007cb8a163bf0598da48361c55d39a69163fa8fd24cf5f83655d23dca3ad961c62f356208552bb9ed529077096966d670c354e4abc9804f1746c08ca18217c32905e462e36ce3be39e772c180e86039b2783a2ec07a28fb5c55df06f4c52c9de2bcbf6955817183995497cea956ae515d2261898fa051015728e5a8aaac42dad33170d04507a33a85521abdf1cba64ecfb850458dbef0a8aea71575d060c7db3970f85a6e1e4c7abf5ae8cdb0933d71e8c94e04a25619dcee3d2261ad2ee6bf12ffa06d98a0864d87602733ec86a64521f2b18177b200cbbe117577a615d6c770988c0bad946e208e24fa074e5ab3143db5bfce0fd108e4b82d120a92108011a723c12a787e6d788719a10bdba5b2699c327186af4e23c1a946834b6150bda2583e9ca2ad44ce8dbbbc2db04de8ef92e8efc141fbecaa6287c59474e6bc05d99b2964fa090c3a2233ba186515be7ed1f612970cee2d7afb81bdd762170481cd0069127d5b05aa993b4ea988d8fddc186ffb7dc90a6c08f4df435c93402849236c3fab4d27c7026c1d4dcb2602646dec9751e763dba37bdf8ff9406ad9e530ee5db382f413001aeb06a53ed9027d831179727b0865a8918da3edbebcf9b14ed44ce6cbaced4bb1bdb7f1447e6cc254b332051512bd7af426fb8f401378cd2bf5983ca01c64b92ecf032ea15d1721d03f482d7ce6e74fef6d55e702f46980c82b5a84031900b1c9e59e7c97fbec7e8f323a97a7e36cc88be0f1d45b7ff585ac54bd407b22b4154aacc8f6d7ebf48e1d814cc5ed20f8037e0a79715eef29be32806a1d58bb7c5da76f550aa3d8a1fbff0eb19ccb1a313d55cda56c9ec2ef29632387fe8d76e3c0468043e8f663f4860ee12bf2d5b0b7474d6e694f91e6dbe115974a3926f12fee5e438777cb6a932df8cd8bec4d073b931ba3bc832b68d9dd300741fa7bf8afc47ed2576f6936ba424663aab639c5ae4f5683423b4742bf1c978238f16cbe39d652de3fdb8befc848ad922222e04a4037c0713eb57a81a23f0c73473fc646cea306b4bcbc8862f8385ddfa9d4b7fa2c087e879683303ed5bdd3a062b3cf5b3a278a66d2a13f83f44f82ddf310ee074ab6a364597e899a0255dc164f31cc50846851df9ab48195ded7ea1b1d510bd7ee74d73faf36bc31ecfa268359046f4eb879f924009438b481c6cd7889a002ed5ee382bc9190da6fc026e479558e4475677e9aa9e3050e2765694dfc81f56e880b96e7160c980dd98edd3dfffffffffffffffff _dh_gex_groups = ((1024, _group1_g, _group1_p), (2048, _group14_g, _group14_p), (3072, _group15_g, _group15_p), (4096, _group16_g, _group16_p), (6144, _group17_g, _group17_p), (8192, _group18_g, _group18_p)) # pylint: enable=bad-whitespace,line-too-long # Short variable names are used here, matching names in the spec # pylint: disable=invalid-name class _KexDHBase(Kex): """Abstract base class for Diffie-Hellman key exchange""" _init_type = None _reply_type = None def __init__(self, alg, conn, hash_alg): super().__init__(alg, conn, hash_alg) self._g = None self._p = None self._q = None self._x = None self._e = None self._f = None self._gex_data = b'' def _init_group(self, g, p): """Initialize DH group parameters""" self._g = g self._p = p self._q = (p - 1) // 2 def _compute_hash(self, host_key_data, k): """Compute a hash of key information associated with the connection""" hash_obj = self._hash_alg() hash_obj.update(self._conn.get_hash_prefix()) hash_obj.update(String(host_key_data)) hash_obj.update(self._gex_data) hash_obj.update(MPInt(self._e)) hash_obj.update(MPInt(self._f)) hash_obj.update(MPInt(k)) return hash_obj.digest() def _send_init(self): """Send a DH init message""" self._conn.send_packet(Byte(self._init_type), MPInt(self._e)) def _send_reply(self, key_data, sig): """Send a DH reply message""" self._conn.send_packet(Byte(self._reply_type), String(key_data), MPInt(self._f), String(sig)) def _perform_init(self): """Compute e and send init message""" self._x = 
randrange(2, self._q) self._e = pow(self._g, self._x, self._p) self._send_init() def _perform_reply(self, key, key_data): """Compute f and send reply message""" if not 1 <= self._e < self._p: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Kex DH e out of range') y = randrange(2, self._q) self._f = pow(self._g, y, self._p) k = pow(self._e, y, self._p) if k < 1: # pragma: no cover, shouldn't be possible with valid p raise DisconnectError(DISC_PROTOCOL_ERROR, 'Kex DH k out of range') h = self._compute_hash(key_data, k) self._send_reply(key_data, key.sign(h)) self._conn.send_newkeys(k, h) def _verify_reply(self, key, key_data, sig): """Verify a DH reply message""" if not 1 <= self._f < self._p: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Kex DH f out of range') k = pow(self._f, self._x, self._p) if k < 1: # pragma: no cover, shouldn't be possible with valid p raise DisconnectError(DISC_PROTOCOL_ERROR, 'Kex DH k out of range') h = self._compute_hash(key_data, k) if not key.verify(h, sig): raise DisconnectError(DISC_KEY_EXCHANGE_FAILED, 'Key exchange hash mismatch') self._conn.send_newkeys(k, h) def _process_init(self, pkttype, packet): """Process a DH init message""" # pylint: disable=unused-argument if self._conn.is_client() or not self._p: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Unexpected kex init msg') self._e = packet.get_mpint() packet.check_end() host_key = self._conn.get_server_host_key() self._perform_reply(host_key, host_key.public_data) def _process_reply(self, pkttype, packet): """Process a DH reply message""" # pylint: disable=unused-argument if self._conn.is_server() or not self._p: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Unexpected kex reply msg') host_key_data = packet.get_string() self._f = packet.get_mpint() sig = packet.get_string() packet.check_end() host_key = self._conn.validate_server_host_key(host_key_data) self._verify_reply(host_key, host_key_data, sig) def start(self): """Start DH key exchange""" self._perform_init() class _KexDH(_KexDHBase): """Handler for Diffie-Hellman key exchange""" _init_type = MSG_KEXDH_INIT _reply_type = MSG_KEXDH_REPLY def __init__(self, alg, conn, hash_alg, g, p): super().__init__(alg, conn, hash_alg) self._init_group(g, p) packet_handlers = { MSG_KEXDH_INIT: _KexDHBase._process_init, MSG_KEXDH_REPLY: _KexDHBase._process_reply } class _KexDHGex(_KexDHBase): """Handler for Diffie-Hellman group exchange""" _init_type = MSG_KEX_DH_GEX_INIT _reply_type = MSG_KEX_DH_GEX_REPLY _request_type = MSG_KEX_DH_GEX_REQUEST _group_type = MSG_KEX_DH_GEX_GROUP def __init__(self, alg, conn, hash_alg, preferred_size=None, max_size=None): super().__init__(alg, conn, hash_alg) self._pref_size = preferred_size self._max_size = max_size def _send_request(self): """Send a DH gex request message""" if self._pref_size and not self._max_size: # Send old request message for unit test request = (Byte(MSG_KEX_DH_GEX_REQUEST_OLD) + UInt32(self._pref_size)) else: request = (Byte(self._request_type) + UInt32(KEX_DH_GEX_MIN_SIZE) + UInt32(self._pref_size or KEX_DH_GEX_PREFERRED_SIZE) + UInt32(self._max_size or KEX_DH_GEX_MAX_SIZE)) self._gex_data = request[1:] self._conn.send_packet(request) def _process_request(self, pkttype, packet): """Process a DH gex request message""" if self._conn.is_client(): raise DisconnectError(DISC_PROTOCOL_ERROR, 'Unexpected kex request msg') self._gex_data = packet.get_remaining_payload() if pkttype == MSG_KEX_DH_GEX_REQUEST_OLD: preferred_size = packet.get_uint32() max_size = KEX_DH_GEX_MAX_SIZE else: _ = packet.get_uint32() 
preferred_size = packet.get_uint32() max_size = packet.get_uint32() packet.check_end() g, p = _group1_g, _group1_p for gex_size, gex_g, gex_p in _dh_gex_groups: if gex_size > max_size: break else: g, p = gex_g, gex_p if gex_size >= preferred_size: break self._init_group(g, p) self._gex_data += MPInt(p) + MPInt(g) self._conn.send_packet(Byte(self._group_type), MPInt(p), MPInt(g)) def _process_group(self, pkttype, packet): """Process a DH gex group message""" # pylint: disable=unused-argument if self._conn.is_server(): raise DisconnectError(DISC_PROTOCOL_ERROR, 'Unexpected kex group msg') p = packet.get_mpint() g = packet.get_mpint() packet.check_end() self._init_group(g, p) self._gex_data += MPInt(p) + MPInt(g) self._perform_init() def start(self): """Start DH group exchange""" self._send_request() packet_handlers = { MSG_KEX_DH_GEX_REQUEST_OLD: _process_request, MSG_KEX_DH_GEX_GROUP: _process_group, MSG_KEX_DH_GEX_INIT: _KexDHBase._process_init, MSG_KEX_DH_GEX_REPLY: _KexDHBase._process_reply, MSG_KEX_DH_GEX_REQUEST: _process_request } class _KexGSSBase(_KexDHBase): """Handler for GSS key exchange""" def __init__(self, alg, conn, hash_alg, *args): super().__init__(alg, conn, hash_alg, *args) self._gss = conn.get_gss_context() self._token = None self._host_key_data = b'' self._got_error = False def _check_secure(self): """Check that GSS context is secure enough for key exchange""" if (not self._gss.provides_mutual_auth or not self._gss.provides_integrity): raise DisconnectError(DISC_PROTOCOL_ERROR, 'GSS context not secure') def _send_init(self): """Send a GSS init message""" if not self._token: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Empty GSS token in init') self._conn.send_packet(Byte(MSG_KEXGSS_INIT), String(self._token), MPInt(self._e)) def _send_reply(self, key_data, sig): """Send a GSS reply message""" if self._token: token_data = Boolean(True) + String(self._token) else: token_data = Boolean(False) self._conn.send_packet(Byte(MSG_KEXGSS_COMPLETE), MPInt(self._f), String(sig), token_data) def _send_continue(self): """Send a GSS continue message""" if not self._token: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Empty GSS token in continue') self._conn.send_packet(Byte(MSG_KEXGSS_CONTINUE), String(self._token)) def _process_token(self, token=None): """Process a GSS token""" try: self._token = self._gss.step(token) except GSSError as exc: if self._conn.is_server(): self._conn.send_packet(Byte(MSG_KEXGSS_ERROR), UInt32(exc.maj_code), UInt32(exc.min_code), String(str(exc)), String(DEFAULT_LANG)) if exc.token: self._conn.send_packet(Byte(MSG_KEXGSS_CONTINUE), String(exc.token)) raise DisconnectError(DISC_KEY_EXCHANGE_FAILED, str(exc)) def _process_init(self, pkttype, packet): """Process a GSS init message""" # pylint: disable=unused-argument if self._conn.is_client() or not self._p: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Unexpected kexgss init msg') token = packet.get_string() self._e = packet.get_mpint() packet.check_end() host_key = self._conn.get_server_host_key() if host_key: self._host_key_data = host_key.public_data self._conn.send_packet(Byte(MSG_KEXGSS_HOSTKEY), String(self._host_key_data)) else: self._host_key_data = b'' self._process_token(token) if self._gss.complete: self._check_secure() self._perform_reply(self._gss, self._host_key_data) self._conn.enable_gss_kex_auth() else: self._send_continue() def _process_continue(self, pkttype, packet): """Process a GSS continue message""" # pylint: disable=unused-argument token = packet.get_string() packet.check_end() if 
self._conn.is_client() and self._gss.complete: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Unexpected kexgss continue msg') self._process_token(token) if self._conn.is_server() and self._gss.complete: self._check_secure() self._perform_reply(self._gss, self._host_key_data) else: self._send_continue() def _process_complete(self, pkttype, packet): """Process a GSS complete message""" # pylint: disable=unused-argument if self._conn.is_server(): raise DisconnectError(DISC_PROTOCOL_ERROR, 'Unexpected kexgss complete msg') self._f = packet.get_mpint() mic = packet.get_string() token_present = packet.get_boolean() token = packet.get_string() if token_present else None packet.check_end() if token: if self._gss.complete: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Non-empty token after complete') self._process_token(token) if self._token: raise DisconnectError(DISC_PROTOCOL_ERROR, 'Non-empty token after complete') if not self._gss.complete: raise DisconnectError(DISC_PROTOCOL_ERROR, 'GSS exchange failed to complete') self._check_secure() self._verify_reply(self._gss, self._host_key_data, mic) self._conn.enable_gss_kex_auth() def _process_hostkey(self, pkttype, packet): """Process a GSS hostkey message""" # pylint: disable=unused-argument self._host_key_data = packet.get_string() packet.check_end() def _process_error(self, pkttype, packet): """Process a GSS error message""" # pylint: disable=unused-argument if self._conn.is_server(): raise DisconnectError(DISC_PROTOCOL_ERROR, 'Unexpected kexgss error msg') _ = packet.get_uint32() # major_status _ = packet.get_uint32() # minor_status msg = packet.get_string() _ = packet.get_string() # lang packet.check_end() logger.warning('GSS error: %s', msg.decode('utf-8', errors='ignore')) self._got_error = True def start(self): """Start GSS key or group exchange""" self._process_token() super().start() class _KexGSS(_KexGSSBase, _KexDH): """Handler for GSS key exchange""" packet_handlers = { MSG_KEXGSS_INIT: _KexGSSBase._process_init, MSG_KEXGSS_CONTINUE: _KexGSSBase._process_continue, MSG_KEXGSS_COMPLETE: _KexGSSBase._process_complete, MSG_KEXGSS_HOSTKEY: _KexGSSBase._process_hostkey, MSG_KEXGSS_ERROR: _KexGSSBase._process_error } class _KexGSSGex(_KexGSSBase, _KexDHGex): """Handler for GSS group exchange""" _request_type = MSG_KEXGSS_GROUPREQ _group_type = MSG_KEXGSS_GROUP packet_handlers = { MSG_KEXGSS_INIT: _KexGSSBase._process_init, MSG_KEXGSS_CONTINUE: _KexGSSBase._process_continue, MSG_KEXGSS_COMPLETE: _KexGSSBase._process_complete, MSG_KEXGSS_HOSTKEY: _KexGSSBase._process_hostkey, MSG_KEXGSS_ERROR: _KexGSSBase._process_error, MSG_KEXGSS_GROUPREQ: _KexDHGex._process_request, MSG_KEXGSS_GROUP: _KexDHGex._process_group } # pylint: disable=bad-whitespace for _name, _hash_alg in ((b'sha256', sha256), (b'sha1', sha1)): register_kex_alg(b'diffie-hellman-group-exchange-' + _name, _KexDHGex, _hash_alg) register_gss_kex_alg(b'gss-gex-' + _name, _KexGSSGex, _hash_alg) for _name, _hash_alg, _g, _p in ( (b'group1-sha1', sha1, _group1_g, _group1_p), (b'group14-sha1', sha1, _group14_g, _group14_p), (b'group14-sha256', sha256, _group14_g, _group14_p), (b'group15-sha512', sha512, _group15_g, _group15_p), (b'group16-sha512', sha512, _group16_g, _group16_p), (b'group17-sha512', sha512, _group17_g, _group17_p), (b'group18-sha512', sha512, _group18_g, _group18_p)): register_kex_alg(b'diffie-hellman-' + _name, _KexDH, _hash_alg, _g, _p) register_gss_kex_alg(b'gss-' + _name, _KexGSS, _hash_alg, _g, _p) 
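# ----------------------------------------------------------------------
# Editorial sketch (not part of the AsyncSSH sources): the _KexDHBase
# handler above follows the classic SSH finite-field Diffie-Hellman flow.
# The client sends e = g^x mod p, the server answers with f = g^y mod p
# plus a signature over an exchange hash, and both sides derive the
# shared secret k = g^(x*y) mod p.  The self-contained example below
# re-creates that math and the hash construction with the standard
# library only, using a deliberately tiny modulus and placeholder byte
# strings for the hash prefix and host key blob -- an illustration under
# those assumptions, not AsyncSSH's internal API.

from hashlib import sha256
from random import SystemRandom


def mpint(n):
    """Encode a positive integer as an SSH mpint (enough for this demo)."""
    data = n.to_bytes((n.bit_length() + 8) // 8, 'big')
    return len(data).to_bytes(4, 'big') + data


def string(data):
    """Encode a byte string with a 32-bit length prefix."""
    return len(data).to_bytes(4, 'big') + data


# Toy 64-bit modulus -- far too small for real use; AsyncSSH registers the
# much larger RFC 3526 groups listed above instead.
g, p = 2, 2**64 - 59
q = (p - 1) // 2
rng = SystemRandom()

x = rng.randrange(2, q)        # client secret
e = pow(g, x, p)               # client public value (sent in KEXDH_INIT)

y = rng.randrange(2, q)        # server secret
f = pow(g, y, p)               # server public value (sent in KEXDH_REPLY)

k_client = pow(f, x, p)        # shared secret as computed by the client
k_server = pow(e, y, p)        # shared secret as computed by the server
assert k_client == k_server

# Exchange hash in the spirit of _compute_hash(): hash prefix and host key
# blob (both placeholders here), then e, f and k as mpints.
h = sha256(b'hash-prefix' + string(b'host-key-blob') +
           mpint(e) + mpint(f) + mpint(k_client)).digest()
print('shared secret bits:', k_client.bit_length(), 'exchange hash:', h.hex())
# ----------------------------------------------------------------------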
asyncssh-1.11.1/asyncssh/dsa.py000066400000000000000000000163421320320510200163740ustar00rootroot00000000000000# Copyright (c) 2013-2017 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """DSA public key encryption handler""" from .asn1 import ASN1DecodeError, ObjectIdentifier, der_encode, der_decode from .crypto import DSAPrivateKey, DSAPublicKey from .misc import all_ints from .packet import MPInt from .public_key import SSHKey, SSHOpenSSHCertificateV01, KeyExportError from .public_key import register_public_key_alg, register_certificate_alg from .public_key import register_x509_certificate_alg # Short variable names are used here, matching names in the spec # pylint: disable=invalid-name class _DSAKey(SSHKey): """Handler for DSA public key encryption""" algorithm = b'ssh-dss' pem_name = b'DSA' pkcs8_oid = ObjectIdentifier('1.2.840.10040.4.1') sig_algorithms = (algorithm,) x509_algorithms = (b'x509v3-' + algorithm,) all_sig_algorithms = set(sig_algorithms) def __eq__(self, other): # This isn't protected access - both objects are _DSAKey instances # pylint: disable=protected-access return (isinstance(other, type(self)) and self._key.p == other._key.p and self._key.q == other._key.q and self._key.g == other._key.g and self._key.y == other._key.y and self._key.x == other._key.x) def __hash__(self): return hash((self._key.p, self._key.q, self._key.g, self._key.y, self._key.x)) @classmethod def generate(cls, algorithm): """Generate a new DSA private key""" # pylint: disable=unused-argument return cls(DSAPrivateKey.generate(key_size=1024)) @classmethod def make_private(cls, *args): """Construct a DSA private key""" return cls(DSAPrivateKey.construct(*args)) @classmethod def make_public(cls, *args): """Construct a DSA public key""" return cls(DSAPublicKey.construct(*args)) @classmethod def decode_pkcs1_private(cls, key_data): """Decode a PKCS#1 format DSA private key""" if (isinstance(key_data, tuple) and len(key_data) == 6 and all_ints(key_data) and key_data[0] == 0): return key_data[1:] else: return None @classmethod def decode_pkcs1_public(cls, key_data): """Decode a PKCS#1 format DSA public key""" if (isinstance(key_data, tuple) and len(key_data) == 4 and all_ints(key_data)): y, p, q, g = key_data return p, q, g, y else: return None @classmethod def decode_pkcs8_private(cls, alg_params, data): """Decode a PKCS#8 format DSA private key""" try: x = der_decode(data) except ASN1DecodeError: return None if (isinstance(alg_params, tuple) and len(alg_params) == 3 and all_ints(alg_params) and isinstance(x, int)): p, q, g = alg_params y = pow(g, x, p) return p, q, g, y, x else: return None @classmethod def decode_pkcs8_public(cls, alg_params, data): """Decode a PKCS#8 format DSA public key""" try: y = der_decode(data) except ASN1DecodeError: return None if (isinstance(alg_params, tuple) and len(alg_params) == 3 and all_ints(alg_params) and isinstance(y, int)): p, q, g = alg_params return p, q, g, y else: return None @classmethod def decode_ssh_private(cls, packet): """Decode an SSH format DSA private key""" p = packet.get_mpint() q = packet.get_mpint() g = packet.get_mpint() y = packet.get_mpint() x = packet.get_mpint() return p, q, g, y, x @classmethod def decode_ssh_public(cls, packet): """Decode an 
SSH format DSA public key""" p = packet.get_mpint() q = packet.get_mpint() g = packet.get_mpint() y = packet.get_mpint() return p, q, g, y def encode_pkcs1_private(self): """Encode a PKCS#1 format DSA private key""" if not self._key.x: raise KeyExportError('Key is not private') return (0, self._key.p, self._key.q, self._key.g, self._key.y, self._key.x) def encode_pkcs1_public(self): """Encode a PKCS#1 format DSA public key""" return (self._key.y, self._key.p, self._key.q, self._key.g) def encode_pkcs8_private(self): """Encode a PKCS#8 format DSA private key""" if not self._key.x: raise KeyExportError('Key is not private') return (self._key.p, self._key.q, self._key.g), der_encode(self._key.x) def encode_pkcs8_public(self): """Encode a PKCS#8 format DSA public key""" return (self._key.p, self._key.q, self._key.g), der_encode(self._key.y) def encode_ssh_private(self): """Encode an SSH format DSA private key""" if not self._key.x: raise KeyExportError('Key is not private') return b''.join((MPInt(self._key.p), MPInt(self._key.q), MPInt(self._key.g), MPInt(self._key.y), MPInt(self._key.x))) def encode_ssh_public(self): """Encode an SSH format DSA public key""" return b''.join((MPInt(self._key.p), MPInt(self._key.q), MPInt(self._key.g), MPInt(self._key.y))) def encode_agent_cert_private(self): """Encode DSA certificate private key data for agent""" if not self._key.x: raise KeyExportError('Key is not private') return MPInt(self._key.x) def sign_der(self, data, sig_algorithm): """Compute a DER-encoded signature of the specified data""" # pylint: disable=unused-argument if not self._key.x: raise ValueError('Private key needed for signing') return self._key.sign(data) def verify_der(self, data, sig_algorithm, sig): """Verify a DER-encoded signature of the specified data""" # pylint: disable=unused-argument return self._key.verify(data, sig) def sign_ssh(self, data, sig_algorithm): """Compute an SSH-encoded signature of the specified data""" r, s = der_decode(self.sign_der(data, sig_algorithm)) return r.to_bytes(20, 'big') + s.to_bytes(20, 'big') def verify_ssh(self, data, sig_algorithm, sig): """Verify an SSH-encoded signature of the specified data""" if len(sig) != 40: return False r = int.from_bytes(sig[:20], 'big') s = int.from_bytes(sig[20:], 'big') return self.verify_der(data, sig_algorithm, der_encode((r, s))) register_public_key_alg(b'ssh-dss', _DSAKey) register_certificate_alg(1, b'ssh-dss', b'ssh-dss-cert-v01@openssh.com', _DSAKey, SSHOpenSSHCertificateV01) for alg in _DSAKey.x509_algorithms: register_x509_certificate_alg(alg) asyncssh-1.11.1/asyncssh/ecdh.py000066400000000000000000000110651320320510200165250ustar00rootroot00000000000000# Copyright (c) 2013-2017 by Ron Frederick . # All rights reserved. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Elliptic curve Diffie-Hellman key exchange handler""" from hashlib import sha256, sha384, sha512 from .constants import DISC_KEY_EXCHANGE_FAILED, DISC_PROTOCOL_ERROR from .kex import Kex, register_kex_alg from .misc import DisconnectError from .packet import Byte, MPInt, String # pylint: disable=bad-whitespace # SSH KEX ECDH message values MSG_KEX_ECDH_INIT = 30 MSG_KEX_ECDH_REPLY = 31 # pylint: enable=bad-whitespace class _KexECDH(Kex): """Handler for elliptic curve Diffie-Hellman key exchange""" def __init__(self, alg, conn, hash_alg, ecdh_class, *args): super().__init__(alg, conn, hash_alg) self._priv = ecdh_class(*args) pub = self._priv.get_public() if conn.is_client(): self._client_pub = pub else: self._server_pub = pub def start(self): """Start ECDH key exchange""" self._conn.send_packet(Byte(MSG_KEX_ECDH_INIT), String(self._client_pub)) def _compute_hash(self, host_key_data, k): """Compute a hash of key information associated with the connection""" hash_obj = self._hash_alg() hash_obj.update(self._conn.get_hash_prefix()) hash_obj.update(String(host_key_data)) hash_obj.update(String(self._client_pub)) hash_obj.update(String(self._server_pub)) hash_obj.update(MPInt(k)) return hash_obj.digest() def _process_init(self, pkttype, packet): """Process an ECDH init message""" # pylint: disable=unused-argument if self._conn.is_client(): raise DisconnectError(DISC_PROTOCOL_ERROR, 'Unexpected kex init msg') self._client_pub = packet.get_string() packet.check_end() try: k = self._priv.get_shared(self._client_pub) except (AssertionError, ValueError): raise DisconnectError(DISC_PROTOCOL_ERROR, 'Invalid kex init msg') from None host_key = self._conn.get_server_host_key() h = self._compute_hash(host_key.public_data, k) sig = host_key.sign(h) self._conn.send_packet(Byte(MSG_KEX_ECDH_REPLY), String(host_key.public_data), String(self._server_pub), String(sig)) self._conn.send_newkeys(k, h) def _process_reply(self, pkttype, packet): """Process an ECDH reply message""" # pylint: disable=unused-argument if self._conn.is_server(): raise DisconnectError(DISC_PROTOCOL_ERROR, 'Unexpected kex reply msg') host_key_data = packet.get_string() self._server_pub = packet.get_string() sig = packet.get_string() packet.check_end() try: k = self._priv.get_shared(self._server_pub) except (AssertionError, ValueError): raise DisconnectError(DISC_PROTOCOL_ERROR, 'Invalid kex reply msg') from None host_key = self._conn.validate_server_host_key(host_key_data) h = self._compute_hash(host_key_data, k) if not host_key.verify(h, sig): raise DisconnectError(DISC_KEY_EXCHANGE_FAILED, 'Key exchange hash mismatch') self._conn.send_newkeys(k, h) packet_handlers = { MSG_KEX_ECDH_INIT: _process_init, MSG_KEX_ECDH_REPLY: _process_reply } try: # pylint: disable=wrong-import-position from .crypto import Curve25519DH except ImportError: # pragma: no cover pass else: register_kex_alg(b'curve25519-sha256', _KexECDH, sha256, Curve25519DH) register_kex_alg(b'curve25519-sha256@libssh.org', _KexECDH, sha256, Curve25519DH) try: # pylint: disable=wrong-import-position from .crypto import ECDH except ImportError: # pragma: no cover pass else: for _curve_id, _hash_alg in ((b'nistp521', sha512), (b'nistp384', sha384), (b'nistp256', sha256)): 
register_kex_alg(b'ecdh-sha2-' + _curve_id, _KexECDH, _hash_alg, ECDH, _curve_id) asyncssh-1.11.1/asyncssh/ecdsa.py000066400000000000000000000250011320320510200166740ustar00rootroot00000000000000# Copyright (c) 2013-2017 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """ECDSA public key encryption handler""" from .asn1 import ASN1DecodeError, BitString, ObjectIdentifier, TaggedDERObject from .asn1 import der_encode, der_decode from .crypto import lookup_ec_curve_by_params from .crypto import ECDSAPrivateKey, ECDSAPublicKey from .packet import MPInt, String, SSHPacket from .public_key import SSHKey, SSHOpenSSHCertificateV01 from .public_key import KeyImportError, KeyExportError from .public_key import register_public_key_alg, register_certificate_alg from .public_key import register_x509_certificate_alg # OID for EC prime fields PRIME_FIELD = ObjectIdentifier('1.2.840.10045.1.1') # Short variable names are used here, matching names in the spec # pylint: disable=invalid-name _alg_oids = {} _alg_oid_map = {} class _ECKey(SSHKey): """Handler for elliptic curve public key encryption""" pem_name = b'EC' pkcs8_oid = ObjectIdentifier('1.2.840.10045.2.1') def __init__(self, key): super().__init__(key) self.algorithm = b'ecdsa-sha2-' + key.curve_id self.sig_algorithms = (self.algorithm,) self.x509_algorithms = (b'x509v3-' + self.algorithm,) self.all_sig_algorithms = set(self.sig_algorithms) self._alg_oid = _alg_oids[key.curve_id] def __eq__(self, other): # This isn't protected access - both objects are _ECKey instances # pylint: disable=protected-access return (isinstance(other, type(self)) and self._key.curve_id == other._key.curve_id and self._key.x == other._key.x and self._key.y == other._key.y and self._key.d == other._key.d) def __hash__(self): return hash((self._key.curve_id, self._key.x, self._key.y, self._key.d)) @classmethod def _lookup_curve(cls, alg_params): """Look up an EC curve matching the specified parameters""" if isinstance(alg_params, ObjectIdentifier): try: curve_id = _alg_oid_map[alg_params] except KeyError: raise KeyImportError('Unknown elliptic curve OID %s', alg_params) from None elif (isinstance(alg_params, tuple) and len(alg_params) >= 5 and alg_params[0] == 1 and isinstance(alg_params[1], tuple) and len(alg_params[1]) == 2 and alg_params[1][0] == PRIME_FIELD and isinstance(alg_params[2], tuple) and len(alg_params[2]) >= 2 and isinstance(alg_params[3], bytes) and isinstance(alg_params[2][0], bytes) and isinstance(alg_params[2][1], bytes) and isinstance(alg_params[4], int)): p = alg_params[1][1] a = int.from_bytes(alg_params[2][0], 'big') b = int.from_bytes(alg_params[2][1], 'big') point = alg_params[3] n = alg_params[4] try: curve_id = lookup_ec_curve_by_params(p, a, b, point, n) except ValueError as exc: raise KeyImportError(str(exc)) from None else: raise KeyImportError('Invalid EC curve parameters') return curve_id @classmethod def generate(cls, algorithm): """Generate a new EC private key""" # Strip 'ecdsa-sha2-' prefix of algorithm to get curve_id return cls(ECDSAPrivateKey.generate(algorithm[11:])) @classmethod def make_private(cls, curve_id, private_key, public_key): """Construct an EC private key""" if isinstance(private_key, bytes): private_key = 
int.from_bytes(private_key, 'big') return cls(ECDSAPrivateKey.construct(curve_id, public_key, private_key)) @classmethod def make_public(cls, curve_id, public_key): """Construct an EC public key""" return cls(ECDSAPublicKey.construct(curve_id, public_key)) @classmethod def decode_pkcs1_private(cls, key_data): """Decode a PKCS#1 format EC private key""" if (isinstance(key_data, tuple) and len(key_data) > 2 and key_data[0] == 1 and isinstance(key_data[1], bytes) and isinstance(key_data[2], TaggedDERObject) and key_data[2].tag == 0): alg_params = key_data[2].value private_key = key_data[1] if (len(key_data) > 3 and isinstance(key_data[3], TaggedDERObject) and key_data[3].tag == 1 and isinstance(key_data[3].value, BitString) and key_data[3].value.unused == 0): public_key = key_data[3].value.value else: public_key = None return cls._lookup_curve(alg_params), private_key, public_key else: return None @classmethod def decode_pkcs1_public(cls, key_data): """Decode a PKCS#1 format EC public key""" # pylint: disable=unused-argument raise KeyImportError('PKCS#1 not supported for EC public keys') @classmethod def decode_pkcs8_private(cls, alg_params, data): """Decode a PKCS#8 format EC private key""" try: key_data = der_decode(data) except ASN1DecodeError: key_data = None if (isinstance(key_data, tuple) and len(key_data) > 1 and key_data[0] == 1 and isinstance(key_data[1], bytes)): private_key = key_data[1] if (len(key_data) > 2 and isinstance(key_data[2], TaggedDERObject) and key_data[2].tag == 1 and isinstance(key_data[2].value, BitString) and key_data[2].value.unused == 0): public_key = key_data[2].value.value else: public_key = None return cls._lookup_curve(alg_params), private_key, public_key else: return None @classmethod def decode_pkcs8_public(cls, alg_params, key_data): """Decode a PKCS#8 format EC public key""" if isinstance(alg_params, ObjectIdentifier): return cls._lookup_curve(alg_params), key_data else: return None @classmethod def decode_ssh_private(cls, packet): """Decode an SSH format EC private key""" curve_id = packet.get_string() public_key = packet.get_string() private_key = packet.get_mpint() return curve_id, private_key, public_key @classmethod def decode_ssh_public(cls, packet): """Decode an SSH format EC public key""" curve_id = packet.get_string() public_key = packet.get_string() return curve_id, public_key def encode_public_tagged(self): """Encode an EC public key blob as a tagged bitstring""" return TaggedDERObject(1, BitString(self._key.public_value)) def encode_pkcs1_private(self): """Encode a PKCS#1 format EC private key""" if not self._key.private_value: raise KeyExportError('Key is not private') return (1, self._key.private_value, TaggedDERObject(0, self._alg_oid), self.encode_public_tagged()) def encode_pkcs1_public(self): """Encode a PKCS#1 format EC public key""" raise KeyExportError('PKCS#1 is not supported for EC public keys') def encode_pkcs8_private(self): """Encode a PKCS#8 format EC private key""" if not self._key.private_value: raise KeyExportError('Key is not private') return self._alg_oid, der_encode((1, self._key.private_value, self.encode_public_tagged())) def encode_pkcs8_public(self): """Encode a PKCS#8 format EC public key""" return self._alg_oid, self._key.public_value def encode_ssh_private(self): """Encode an SSH format EC private key""" if not self._key.d: raise KeyExportError('Key is not private') return b''.join((String(self._key.curve_id), String(self._key.public_value), MPInt(self._key.d))) def encode_ssh_public(self): """Encode an SSH format 
EC public key""" return b''.join((String(self._key.curve_id), String(self._key.public_value))) def encode_agent_cert_private(self): """Encode ECDSA certificate private key data for agent""" if not self._key.d: raise KeyExportError('Key is not private') return MPInt(self._key.d) def sign_der(self, data, sig_algorithm): """Compute a DER-encoded signature of the specified data""" # pylint: disable=unused-argument if not self._key.private_value: raise ValueError('Private key needed for signing') return self._key.sign(data) def verify_der(self, data, sig_algorithm, sig): """Verify a DER-encoded signature of the specified data""" # pylint: disable=unused-argument return self._key.verify(data, sig) def sign_ssh(self, data, sig_algorithm): """Compute an SSH-encoded signature of the specified data""" r, s = der_decode(self.sign_der(data, sig_algorithm)) return MPInt(r) + MPInt(s) def verify_ssh(self, data, sig_algorithm, sig): """Verify an SSH-encoded signature of the specified data""" packet = SSHPacket(sig) r = packet.get_mpint() s = packet.get_mpint() packet.check_end() return self.verify_der(data, sig_algorithm, der_encode((r, s))) for _curve_id, _oid in ((b'nistp521', '1.3.132.0.35'), (b'nistp384', '1.3.132.0.34'), (b'nistp256', '1.2.840.10045.3.1.7')): _algorithm = b'ecdsa-sha2-' + _curve_id _cert_algorithm = _algorithm + b'-cert-v01@openssh.com' _x509_algorithm = b'x509v3-' + _algorithm _oid = ObjectIdentifier(_oid) _alg_oids[_curve_id] = _oid _alg_oid_map[_oid] = _curve_id register_public_key_alg(_algorithm, _ECKey, (_algorithm,)) register_certificate_alg(1, _algorithm, _cert_algorithm, _ECKey, SSHOpenSSHCertificateV01) register_x509_certificate_alg(_x509_algorithm) asyncssh-1.11.1/asyncssh/ed25519.py000066400000000000000000000077151320320510200166270ustar00rootroot00000000000000# Copyright (c) 2015-2017 by Ron Frederick . # All rights reserved. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Ed25519 public key encryption handler""" from .packet import String from .public_key import SSHKey, SSHOpenSSHCertificateV01, KeyExportError from .public_key import register_public_key_alg, register_certificate_alg # Short variable names are used here, matching names in the spec # pylint: disable=invalid-name class _Ed25519Key(SSHKey): """Handler for Ed25519 public key encryption""" algorithm = b'ssh-ed25519' sig_algorithms = (algorithm,) all_sig_algorithms = set(sig_algorithms) def __init__(self, vk, sk): super().__init__() self._vk = vk self._sk = sk def __eq__(self, other): # This isn't protected access - both objects are _Ed25519Key instances # pylint: disable=protected-access return (isinstance(other, type(self)) and self._vk == other._vk and self._sk == other._sk) def __hash__(self): return hash(self._vk) @classmethod def generate(cls, algorithm): """Generate a new Ed25519 private key""" # pylint: disable=unused-argument return cls(*libnacl.crypto_sign_keypair()) @classmethod def make_private(cls, vk, sk): """Construct an Ed25519 private key""" return cls(vk, sk) @classmethod def make_public(cls, vk): """Construct an Ed25519 public key""" return cls(vk, None) @classmethod def decode_ssh_private(cls, packet): """Decode an SSH format Ed25519 private key""" vk = packet.get_string() sk = packet.get_string() return vk, sk @classmethod def decode_ssh_public(cls, packet): """Decode an SSH format Ed25519 public key""" vk = packet.get_string() return (vk,) def encode_ssh_private(self): """Encode an SSH format Ed25519 private key""" if self._sk is None: raise KeyExportError('Key is not private') return b''.join((String(self._vk), String(self._sk))) def encode_ssh_public(self): """Encode an SSH format Ed25519 public key""" return String(self._vk) def encode_agent_cert_private(self): """Encode Ed25519 certificate private key data for agent""" return self.encode_ssh_private() def sign_der(self, data, sig_algorithm): """Return a DER-encoded signature of the specified data""" # pylint: disable=unused-argument if self._sk is None: raise ValueError('Private key needed for signing') sig = libnacl.crypto_sign(data, self._sk) return sig[:-len(data)] def verify_der(self, data, sig_algorithm, sig): """Verify a DER-encoded signature of the specified data""" # pylint: disable=unused-argument try: return libnacl.crypto_sign_open(sig + data, self._vk) == data except ValueError: return False def sign_ssh(self, data, sig_algorithm): """Return an SSH-encoded signature of the specified data""" return self.sign_der(data, sig_algorithm) def verify_ssh(self, data, sig_algorithm, sig): """Verify an SSH-encoded signature of the specified data""" return self.verify_der(data, sig_algorithm, sig) try: # pylint: disable=wrong-import-position,wrong-import-order import libnacl except (ImportError, OSError): # pragma: no cover pass else: register_public_key_alg(b'ssh-ed25519', _Ed25519Key) register_certificate_alg(1, b'ssh-ed25519', b'ssh-ed25519-cert-v01@openssh.com', _Ed25519Key, SSHOpenSSHCertificateV01) asyncssh-1.11.1/asyncssh/editor.py000066400000000000000000000420251320320510200171100ustar00rootroot00000000000000# Copyright (c) 2016 by Ron Frederick . # All rights reserved. 
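# ----------------------------------------------------------------------
# Editorial sketch (not part of the AsyncSSH sources): _Ed25519Key above
# builds SSH signing on libnacl's "attached" signature primitives --
# crypto_sign() returns signature||message, so sign_der() strips the
# trailing message and verify_der() re-attaches it before calling
# crypto_sign_open().  The standalone example below (assuming libnacl is
# installed) shows the same detach/re-attach round trip.

import libnacl

vk, sk = libnacl.crypto_sign_keypair()      # verify key, signing key
data = b'example payload'

signed = libnacl.crypto_sign(data, sk)      # signature || message
detached_sig = signed[:-len(data)]          # 64-byte Ed25519 signature

# Verification: rebuild the attached form and open it with the verify key
assert libnacl.crypto_sign_open(detached_sig + data, vk) == data
print('detached signature length:', len(detached_sig))
# ----------------------------------------------------------------------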
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Input line editor""" from unicodedata import east_asian_width _DEFAULT_WIDTH = 80 _ansi_terminals = ('ansi', 'cygwin', 'linux', 'putty', 'screen', 'teraterm', 'cit80', 'vt100', 'vt102', 'vt220', 'vt320', 'xterm', 'xterm-color', 'xterm-16color', 'xterm-256color', 'rxvt', 'rxvt-color') def _is_wide(ch): """Return display width of character""" return east_asian_width(ch) in 'WF' class SSHLineEditor: """Input line editor""" def __init__(self, chan, session, history_size, term_type, width): self._chan = chan self._session = session self._history_size = history_size if history_size > 0 else 0 self._wrap = term_type in _ansi_terminals self._width = width or _DEFAULT_WIDTH self._line_mode = True self._echo = True self._start_column = 0 self._end_column = 0 self._cursor = 0 self._left_pos = 0 self._right_pos = 0 self._pos = 0 self._line = '' self._key_state = self._keymap self._erased = '' self._history = [] self._history_index = 0 @classmethod def build_keymap(cls): """Build keyboard input map""" cls._keymap = {} for func, keys in cls._keylist: for key in keys: keymap = cls._keymap for ch in key[:-1]: if ch not in keymap: keymap[ch] = {} keymap = keymap[ch] keymap[key[-1]] = func def _determine_column(self, data, start): """Determine new output column after output occurs""" column = start for ch in data: if ch == '\b': column -= 1 else: if _is_wide(ch) and (column % self._width) == self._width - 1: column += 1 column += 2 if _is_wide(ch) else 1 return column def _column(self, pos): """Determine output column of end of current input line""" return self._determine_column(self._line[self._left_pos:pos], self._start_column) def _output(self, data): """Generate output and calculate new output column""" self._chan.write(data.replace('\n', '\r\n')) idx = data.rfind('\n') if idx >= 0: tail = data[idx+1:] self._cursor = 0 else: tail = data self._cursor = self._determine_column(tail, self._cursor) if (self._line_mode and self._cursor and self._cursor % self._width == 0): self._chan.write(' \b') def _ring_bell(self): """Ring the terminal bell""" self._chan.write('\a') def _update_input_window(self, pos): """Update visible input window when not wrapping onto multiple lines""" if pos < self._left_pos: self._left_pos = pos else: if pos < len(self._line): pos += 1 if self._column(pos) >= self._width: while self._column(pos) >= self._width: self._left_pos += 1 else: while self._left_pos > 0: self._left_pos -= 1 if self._column(pos) >= self._width: self._left_pos += 1 break column = self._start_column self._right_pos = self._left_pos while self._right_pos < len(self._line): column += 1 + _is_wide(self._line[self._right_pos]) if column < self._width: self._right_pos += 1 else: break def _update_line(self, start_pos=None, end_pos=None): """Update display of selected portion of input line""" self._output(self._line[start_pos:end_pos]) if self._end_column > self._cursor: new_end_column = self._cursor self._output(' ' * (self._end_column - new_end_column)) self._end_column = new_end_column else: self._end_column = self._cursor def _move_cursor(self, pos): """Move the cursor to selected position in input line""" if self._wrap: new_column = self._column(pos) start_row = self._cursor // self._width start_col = 
self._cursor % self._width end_row = new_column // self._width end_col = new_column % self._width if end_row < start_row: self._chan.write('\x1b[' + str(start_row-end_row) + 'A') elif end_row > start_row: self._chan.write('\x1b[' + str(end_row-start_row) + 'B') if end_col > start_col: self._chan.write('\x1b[' + str(end_col-start_col) + 'C') elif end_col < start_col: self._chan.write('\x1b[' + str(start_col-end_col) + 'D') self._cursor = new_column else: self._update_input_window(pos) self._output('\b' * (self._cursor - self._start_column)) self._update_line(self._left_pos, self._right_pos) self._output('\b' * (self._cursor - self._column(pos))) def _reposition(self, new_pos): """Reposition the cursor to selected position in input""" if self._line_mode and self._echo: self._move_cursor(new_pos) self._pos = new_pos def _erase_input(self): """Erase current input line""" if self._start_column != self._end_column: self._move_cursor(0) self._output(' ' * (self._end_column - self._cursor)) self._move_cursor(0) self._end_column = self._start_column def _draw_input(self): """Draw current input line""" if (self._line_mode and self._echo and self._line and self._start_column == self._end_column): if self._wrap: self._update_line() else: self._update_input_window(self._pos) self._update_line(self._left_pos, self._right_pos) self._move_cursor(self._pos) def _update_input(self, start_pos, new_pos): """Update selected portion of current input line""" if self._line_mode and self._echo: self._move_cursor(start_pos) if self._wrap: self._update_line(start_pos) self._reposition(new_pos) def _insert_printable(self, data): """Insert data into the input line""" data_len = len(data) self._line = self._line[:self._pos] + data + self._line[self._pos:] self._pos += data_len if self._line_mode and self._echo: self._update_input(self._pos - data_len, self._pos) def _end_line(self): """End the current input line and send it to the session""" if (self._echo and not self._wrap and (self._left_pos > 0 or self._right_pos < len(self._line))): self._output('\b' * (self._cursor - self._start_column) + self._line) else: self._reposition(len(self._line)) self._output('\n') self._start_column = 0 self._end_column = 0 self._cursor = 0 self._left_pos = 0 self._right_pos = 0 self._pos = 0 if self._echo and self._history_size and self._line: self._history.append(self._line) self._history = self._history[-self._history_size:] self._history_index = len(self._history) data = self._line + '\n' self._line = '' self._session.data_received(data, None) def _eof_or_delete(self): """Erase character to the right, or send EOF if input line is empty""" if not self._line: self._session.eof_received() else: self._erase_right() def _erase_left(self): """Erase character to the left""" if self._pos > 0: self._line = self._line[:self._pos-1] + self._line[self._pos:] self._update_input(self._pos - 1, self._pos - 1) else: self._ring_bell() def _erase_right(self): """Erase character to the right""" if self._pos < len(self._line): self._line = self._line[:self._pos] + self._line[self._pos+1:] self._update_input(self._pos, self._pos) else: self._ring_bell() def _erase_line(self): """Erase entire input line""" self._erased = self._line self._line = '' self._update_input(0, 0) def _erase_to_end(self): """Erase to end of input line""" self._erased = self._line[self._pos:] self._line = self._line[:self._pos] self._update_input(self._pos, self._pos) def _history_prev(self): """Replace input with previous line in history""" if self._history_index > 0: 
self._history_index -= 1 self._line = self._history[self._history_index] self._update_input(0, len(self._line)) else: self._ring_bell() def _history_next(self): """Replace input with next line in history""" if self._history_index < len(self._history): self._history_index += 1 if self._history_index < len(self._history): self._line = self._history[self._history_index] else: self._line = '' self._update_input(0, len(self._line)) else: self._ring_bell() def _move_left(self): """Move left in input line""" if self._pos > 0: self._reposition(self._pos - 1) else: self._ring_bell() def _move_right(self): """Move right in input line""" if self._pos < len(self._line): self._reposition(self._pos + 1) else: self._ring_bell() def _move_to_start(self): """Move to start of input line""" self._reposition(0) def _move_to_end(self): """Move to end of input line""" self._reposition(len(self._line)) def _redraw(self): """Redraw input line""" self._erase_input() self._draw_input() def _insert_erased(self): """Insert previously erased input""" self._insert_printable(self._erased) def _send_break(self): """Send break to session""" self._session.break_received(0) # pylint: disable=bad-whitespace _keylist = ((_end_line, ('\n', '\r', '\x1bOM')), (_eof_or_delete, ('\x04',)), (_erase_left, ('\x08', '\x7f')), (_erase_right, ('\x1b[3~',)), (_erase_line, ('\x15',)), (_erase_to_end, ('\x0b',)), (_history_prev, ('\x10', '\x1b[A', '\x1bOA')), (_history_next, ('\x0e', '\x1b[B', '\x1bOB')), (_move_left, ('\x02', '\x1b[D', '\x1bOD')), (_move_right, ('\x06', '\x1b[C', '\x1bOC')), (_move_to_start, ('\x01', '\x1b[H', '\x1b[1~')), (_move_to_end, ('\x05', '\x1b[F', '\x1b[4~')), (_redraw, ('\x12',)), (_insert_erased, ('\x19',)), (_send_break, ('\x03', '\x1b[33~'))) # pylint: enable=bad-whitespace def set_line_mode(self, line_mode): """Enable/disable input line editing""" if self._line and not line_mode: data = self._line self._erase_input() self._line = '' self._session.data_received(data, None) self._line_mode = line_mode def set_echo(self, echo): """Enable/disable echoing of input in line mode""" if self._echo and not echo: self._erase_input() self._echo = False elif echo and not self._echo: self._echo = True self._draw_input() def set_width(self, width): """Set terminal line width""" self._width = width or _DEFAULT_WIDTH if self._wrap: self._cursor = self._column(self._pos) self._redraw() def process_input(self, data, datatype): """Process input from channel""" if self._line_mode: for ch in data: if ch in self._key_state: self._key_state = self._key_state[ch] if callable(self._key_state): try: self._key_state(self) finally: self._key_state = self._keymap elif self._key_state == self._keymap and ch.isprintable(): self._insert_printable(ch) else: self._key_state = self._keymap self._ring_bell() else: self._session.data_received(data, datatype) def process_output(self, data): """Process output to channel""" self._erase_input() self._output(data) self._start_column = self._cursor self._end_column = self._cursor self._draw_input() class SSHLineEditorChannel: """Input line editor channel wrapper When creating server channels with ``line_editor`` set to ``True``, this class is wrapped around the channel, providing the caller with the ability to enable and disable input line editing and echoing. .. note:: Line editing is only available when a psuedo-terminal is requested on the server channel and the character encoding on the channel is not set to ``None``. 
""" def __init__(self, orig_chan, orig_session, history_size): self._orig_chan = orig_chan self._orig_session = orig_session self._history_size = history_size self._editor = None def __getattr__(self, attr): """Delegate most channel functions to original channel""" return getattr(self._orig_chan, attr) def create_editor(self): """Create input line editor if encoding and terminal type are set""" if self._encoding and self._term_type: self._editor = SSHLineEditor(self._orig_chan, self._orig_session, self._history_size, self._term_type, self._term_size[0]) return self._editor def set_line_mode(self, line_mode): """Enable/disable input line editing This method enabled or disables input line editing. When set, only full lines of input are sent to the session, and each line of input can be edited before it is sent. :param bool line_mode: Whether or not to process input a line at a time """ self._editor.set_line_mode(line_mode) def set_echo(self, echo): """Enable/disable echoing of input in line mode This method enables or disables echoing of input data when input line editing is enabled. :param bool echo: Whether or not input to echo input as it is entered """ self._editor.set_echo(echo) def write(self, data, datatype=None): """Process data written to the channel""" if self._editor: self._editor.process_output(data) else: self._orig_chan.write(data, datatype) class SSHLineEditorSession: """Input line editor session wrapper""" def __init__(self, chan, orig_session): self._chan = chan self._orig_session = orig_session self._editor = None def __getattr__(self, attr): """Delegate most channel functions to original session""" return getattr(self._orig_session, attr) def session_started(self): """Start a session for this newly opened server channel""" self._editor = self._chan.create_editor() self._orig_session.session_started() def terminal_size_changed(self, width, height, pixwidth, pixheight): """The terminal size has changed""" if self._editor: self._editor.set_width(width) self._orig_session.terminal_size_changed(width, height, pixwidth, pixheight) def data_received(self, data, datatype): """Process data received from the channel""" if self._editor: self._editor.process_input(data, datatype) else: self._orig_session.data_received(data, datatype) def eof_received(self): """Process EOF received from the channel""" if self._editor: self._editor.set_line_mode(False) return self._orig_session.eof_received() SSHLineEditor.build_keymap() asyncssh-1.11.1/asyncssh/forward.py000066400000000000000000000107621320320510200172710ustar00rootroot00000000000000# Copyright (c) 2013-2016 by Ron Frederick . # All rights reserved. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH port forwarding handlers""" import asyncio import socket from .misc import ChannelOpenError class SSHForwarder: """SSH port forwarding connection handler""" def __init__(self, peer=None): self._peer = peer self._transport = None self._inpbuf = b'' self._eof_received = False if peer: peer.set_peer(self) def set_peer(self, peer): """Set the peer forwarder to exchange data with""" self._peer = peer def write(self, data): """Write data to the transport""" self._transport.write(data) def write_eof(self): """Write end of file to the transport""" self._transport.write_eof() def was_eof_received(self): """Return whether end of file has been received or not""" return self._eof_received def pause_reading(self): """Pause reading from the transport""" self._transport.pause_reading() def resume_reading(self): """Resume reading on the transport""" self._transport.resume_reading() def connection_made(self, transport): """Handle a newly opened connection""" self._transport = transport sock = transport.get_extra_info('socket') if sock.family in {socket.AF_INET, socket.AF_INET6}: sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) def connection_lost(self, exc): """Handle an incoming connection close""" # pylint: disable=unused-argument self.close() def session_started(self): """Handle session start""" def data_received(self, data, datatype=None): """Handle incoming data from the transport""" # pylint: disable=unused-argument if self._peer: try: self._peer.write(data) except OSError: # pragma: no cover pass else: self._inpbuf += data def eof_received(self): """Handle an incoming end of file from the transport""" self._eof_received = True if self._peer: try: self._peer.write_eof() except OSError: # pragma: no cover pass return not self._peer.was_eof_received() else: return False def pause_writing(self): """Pause writing by asking peer to pause reading""" self._peer.pause_reading() def resume_writing(self): """Resume writing by asking peer to resume reading""" self._peer.resume_reading() def close(self): """Close this port forwarder""" if self._transport: self._transport.close() self._transport = None if self._peer: peer = self._peer self._peer = None peer.close() class SSHLocalForwarder(SSHForwarder): """Local forwarding connection handler""" def __init__(self, conn, coro): super().__init__() self._conn = conn self._coro = coro @asyncio.coroutine def _forward(self, *args): """Begin local forwarding""" def session_factory(): """Return an SSH forwarder""" return SSHForwarder(self) try: yield from self._coro(session_factory, *args) except ChannelOpenError: self.close() return if self._inpbuf: self.data_received(self._inpbuf) self._inpbuf = b'' class SSHLocalPortForwarder(SSHLocalForwarder): """Local TCP port forwarding connection handler""" def connection_made(self, transport): """Handle a newly opened connection""" super().connection_made(transport) orig_host, orig_port = transport.get_extra_info('peername')[:2] self._conn.create_task(self._forward(orig_host, orig_port)) class SSHLocalPathForwarder(SSHLocalForwarder): """Local UNIX domain socket forwarding connection handler""" def connection_made(self, transport): """Handle a newly opened connection""" super().connection_made(transport) 
self._conn.create_task(self._forward()) asyncssh-1.11.1/asyncssh/gss.py000066400000000000000000000027321320320510200164170ustar00rootroot00000000000000# Copyright (c) 2017 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """GSSAPI wrapper""" import sys try: # pylint: disable=unused-import if sys.platform == 'win32': # pragma: no cover from .gss_win32 import GSSError, GSSClient, GSSServer else: from .gss_unix import GSSError, GSSClient, GSSServer gss_available = True except ImportError: # pragma: no cover gss_available = False class GSSError(ValueError): """Stub class for reporting that GSS is not available""" def __init__(self, maj_code=0, min_code=0, token=None): super().__init__('GSS not available') self.maj_code = maj_code self.min_code = min_code self.token = token class GSSClient: """Stub client class for reporting that GSS is not available""" def __init__(self, host, delegate_creds): # pylint: disable=unused-argument raise GSSError() class GSSServer: """Stub client class for reporting that GSS is not available""" def __init__(self, host): # pylint: disable=unused-argument raise GSSError() asyncssh-1.11.1/asyncssh/gss_unix.py000066400000000000000000000074101320320510200174600ustar00rootroot00000000000000# Copyright (c) 2017 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """GSSAPI wrapper for UNIX""" from gssapi import Credentials, Name, NameType from gssapi import RequirementFlag, SecurityContext from gssapi.exceptions import GSSError from .asn1 import OBJECT_IDENTIFIER def _mech_to_oid(mech): """Return a DER-encoded OID corresponding to the requested GSS mechanism""" mech = bytes(mech) return bytes((OBJECT_IDENTIFIER, len(mech))) + mech class _GSSBase: """GSS base class""" def __init__(self, host, usage): if '@' in host: self._host = Name(host) else: self._host = Name('host@' + host, NameType.hostbased_service) if usage == 'initiate': self._creds = Credentials(usage=usage) else: self._creds = Credentials(name=self._host, usage=usage) self._mechs = [_mech_to_oid(mech) for mech in self._creds.mechs] self._ctx = None def _init_context(self): """Abstract method to construct GSS security context""" raise NotImplementedError @property def mechs(self): """Return GSS mechanisms available for this host""" return self._mechs @property def complete(self): """Return whether or not GSS negotiation is complete""" return self._ctx and self._ctx.complete @property def provides_mutual_auth(self): """Return whether or not this context provides mutual authentication""" return (RequirementFlag.mutual_authentication in self._ctx.actual_flags) @property def provides_integrity(self): """Return whether or not this context provides integrity protection""" return RequirementFlag.integrity in self._ctx.actual_flags @property def user(self): """Return user principal associated with this context""" return str(self._ctx.initiator_name) @property def host(self): """Return host principal associated with this context""" 
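# ----------------------------------------------------------------------
# Editorial sketch (not part of the AsyncSSH sources): gss.py above makes
# GSSAPI support optional by trying to import a platform-specific wrapper
# and, on ImportError, exposing stub classes that raise only when they are
# actually used, so the rest of the code can import the names
# unconditionally.  The example below shows the same optional-import
# pattern with made-up module and class names.

try:
    from some_optional_backend import Client     # hypothetical dependency
    backend_available = True
except ImportError:
    backend_available = False

    class Client:
        """Stub that reports the backend is unavailable when instantiated."""

        def __init__(self, *args, **kwargs):
            raise RuntimeError('optional backend not available')

# Raising from __init__ (rather than at import time) means callers only
# see an error if they try to use the optional feature.
if backend_available:
    client = Client()
else:
    print('falling back: optional backend not installed')
# ----------------------------------------------------------------------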
return str(self._ctx.target_name) def reset(self): """Reset GSS security context""" self._ctx = None def step(self, token=None): """Perform next step in GSS security exchange""" if not self._ctx: self._init_context() return self._ctx.step(token) def sign(self, data): """Sign a block of data""" return self._ctx.get_signature(data) def verify(self, data, sig): """Verify a signature for a block of data""" try: self._ctx.verify_signature(data, sig) return True except GSSError: return False class GSSClient(_GSSBase): """GSS client""" def __init__(self, host, delegate_creds): super().__init__(host, 'initiate') flags = set((RequirementFlag.mutual_authentication, RequirementFlag.integrity)) if delegate_creds: flags.add(RequirementFlag.delegate_to_peer) self._flags = flags def _init_context(self): """Construct GSS client security context""" self._ctx = SecurityContext(name=self._host, creds=self._creds, flags=self._flags) class GSSServer(_GSSBase): """GSS server""" def __init__(self, host): super().__init__(host, 'accept') def _init_context(self): """Construct GSS server security context""" self._ctx = SecurityContext(creds=self._creds) asyncssh-1.11.1/asyncssh/gss_win32.py000066400000000000000000000111271320320510200174370ustar00rootroot00000000000000# Copyright (c) 2017 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """GSSAPI wrapper for Windows""" # Some of the imports below won't be found when running pylint on UNIX # pylint: disable=import-error from sspi import ClientAuth, ServerAuth from sspi import error as SSPIError from sspicon import ISC_REQ_DELEGATE, ISC_REQ_INTEGRITY, ISC_REQ_MUTUAL_AUTH from sspicon import ISC_RET_INTEGRITY, ISC_RET_MUTUAL_AUTH from sspicon import ASC_REQ_INTEGRITY, ASC_REQ_MUTUAL_AUTH from sspicon import ASC_RET_INTEGRITY, ASC_RET_MUTUAL_AUTH from sspicon import SECPKG_ATTR_NATIVE_NAMES from .asn1 import ObjectIdentifier, der_encode _krb5_oid = der_encode(ObjectIdentifier('1.2.840.113554.1.2.2')) class _GSSBase: """GSS base class""" # Overridden in client classes _mutual_auth_flag = 0 _integrity_flag = 0 def __init__(self, host): if '@' in host: self._host = host else: self._host = 'host/' + host self._ctx = None self._init_token = None @property def mechs(self): """Return GSS mechanisms available for this host""" return [_krb5_oid] @property def complete(self): """Return whether or not GSS negotiation is complete""" return self._ctx.authenticated @property def provides_mutual_auth(self): """Return whether or not this context provides mutual authentication""" return self._ctx.ctxt_attr & self._mutual_auth_flag @property def provides_integrity(self): """Return whether or not this context provides integrity protection""" return self._ctx.ctxt_attr & self._integrity_flag @property def user(self): """Return user principal associated with this context""" names = self._ctx.ctxt.QueryContextAttributes(SECPKG_ATTR_NATIVE_NAMES) return names[0] @property def host(self): """Return host principal associated with this context""" names = self._ctx.ctxt.QueryContextAttributes(SECPKG_ATTR_NATIVE_NAMES) return names[1] def reset(self): """Reset GSS security context""" if self._ctx.authenticated: self._ctx.reset() def step(self, token=None): """Perform next step in GSS security 
exchange""" if self._init_token: token = self._init_token self._init_token = None return token try: _, buf = self._ctx.authorize(token) return buf[0].Buffer except SSPIError as exc: raise GSSError(details=exc.strerror) from None def sign(self, data): """Sign a block of data""" try: return self._ctx.sign(data) except SSPIError as exc: raise GSSError(details=exc.strerror) from None def verify(self, data, sig): """Verify a signature for a block of data""" try: self._ctx.verify(data, sig) return True except SSPIError: return False class GSSClient(_GSSBase): """GSS client""" _mutual_auth_flag = ISC_RET_MUTUAL_AUTH _integrity_flag = ISC_RET_INTEGRITY def __init__(self, host, delegate_creds): super().__init__(host) flags = ISC_REQ_MUTUAL_AUTH | ISC_REQ_INTEGRITY if delegate_creds: flags |= ISC_REQ_DELEGATE try: self._ctx = ClientAuth('Kerberos', targetspn=self._host, scflags=flags) except SSPIError as exc: raise GSSError(1, 1, details=exc.strerror) self._init_token = self.step(None) class GSSServer(_GSSBase): """GSS server""" _mutual_auth_flag = ASC_RET_MUTUAL_AUTH _integrity_flag = ASC_RET_INTEGRITY def __init__(self, host): super().__init__(host) flags = ASC_REQ_MUTUAL_AUTH | ASC_REQ_INTEGRITY try: self._ctx = ServerAuth('Kerberos', spn=self._host, scflags=flags) except SSPIError as exc: raise GSSError(1, 1, details=exc.strerror) class GSSError(Exception): """Stub class for reporting that GSS is not available""" def __init__(self, maj_code=0, min_code=0, token=None, details=''): super().__init__(details) self.maj_code = maj_code self.min_code = min_code self.token = token asyncssh-1.11.1/asyncssh/kex.py000066400000000000000000000051171320320510200164120ustar00rootroot00000000000000# Copyright (c) 2013-2017 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH key exchange handlers""" import binascii from hashlib import md5 from .packet import MPInt, SSHPacketHandler _kex_algs = [] _kex_handlers = {} _gss_kex_algs = [] _gss_kex_handlers = {} class Kex(SSHPacketHandler): """Parent class for key exchange handlers""" def __init__(self, alg, conn, hash_alg): self.algorithm = alg self._conn = conn self._hash_alg = hash_alg def compute_key(self, k, h, x, session_id, keylen): """Compute keys from output of key exchange""" key = b'' while len(key) < keylen: hash_obj = self._hash_alg() hash_obj.update(MPInt(k)) hash_obj.update(h) hash_obj.update(key if key else x + session_id) key += hash_obj.digest() return key[:keylen] def register_kex_alg(alg, handler, hash_alg, *args): """Register a key exchange algorithm""" _kex_algs.append(alg) _kex_handlers[alg] = (handler, hash_alg, args) def register_gss_kex_alg(alg, handler, hash_alg, *args): """Register a GSSAPI key exchange algorithm""" _gss_kex_algs.append(alg) _gss_kex_handlers[alg] = (handler, hash_alg, args) def get_kex_algs(): """Return a list of available key exchange algorithms""" return _gss_kex_algs + _kex_algs def expand_kex_algs(kex_algs, mechs, host_key_available): """Add mechanisms to GSS entries in key exchange algorithm list""" expanded_kex_algs = [] for alg in kex_algs: if alg.startswith(b'gss-'): for mech in mechs: suffix = b'-' + binascii.b2a_base64(md5(mech).digest())[:-1] expanded_kex_algs.append(alg + suffix) elif host_key_available: 
expanded_kex_algs.append(alg) return expanded_kex_algs def get_kex(conn, alg): """Return a key exchange handler The function looks up a key exchange algorithm and returns a handler which can perform that type of key exchange. """ if alg.startswith(b'gss-'): alg = alg.rsplit(b'-', 1)[0] handler, hash_alg, args = _gss_kex_handlers[alg] else: handler, hash_alg, args = _kex_handlers[alg] return handler(alg, conn, hash_alg, *args) asyncssh-1.11.1/asyncssh/known_hosts.py000066400000000000000000000252341320320510200202010ustar00rootroot00000000000000# Copyright (c) 2015-2017 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation # Alexander Travov - proposed changes to add negated patterns, hashed # entries, and support for the revoked marker # Josh Yudaken - proposed change to split parsing and matching to avoid # parsing large known_hosts lists multiple times """Parser for SSH known_hosts files""" import binascii import hmac from hashlib import sha1 try: from .crypto import X509NamePattern _x509_available = True except ImportError: # pragma: no cover _x509_available = False from .misc import ip_address from .pattern import HostPatternList from .public_key import KeyImportError, import_public_key from .public_key import import_certificate, import_certificate_subject from .public_key import load_public_keys, load_certificates def _load_subject_names(names): """Load a list of X.509 subject name patterns""" if not _x509_available: # pragma: no cover return [] return list(map(X509NamePattern, names)) class _PlainHost: """A plain host entry in a known_hosts file""" def __init__(self, pattern): self._pattern = HostPatternList(pattern) def matches(self, host, addr, ip): """Return whether a host or address matches this host pattern list""" return self._pattern.matches(host, addr, ip) class _HashedHost: """A hashed host entry in a known_hosts file""" _HMAC_SHA1_MAGIC = '1' def __init__(self, pattern): try: magic, salt, hosthash = pattern[1:].split('|') self._salt = binascii.a2b_base64(salt) self._hosthash = binascii.a2b_base64(hosthash) except (ValueError, binascii.Error): raise ValueError('Invalid known hosts hash entry: %s' % pattern) from None if magic != self._HMAC_SHA1_MAGIC: # Only support HMAC SHA-1 for now raise ValueError('Invalid known hosts hash type: %s' % magic) from None def _match(self, value): """Return whether this host hash matches a value""" hosthash = hmac.new(self._salt, value.encode(), sha1).digest() return hosthash == self._hosthash def matches(self, host, addr, ip): """Return whether a host or address matches this host hash""" # pylint: disable=unused-argument return (host and self._match(host)) or (addr and self._match(addr)) class SSHKnownHosts: """An SSH known hosts list""" def __init__(self, known_hosts): self._exact_entries = {} self._pattern_entries = [] for line in known_hosts.splitlines(): line = line.strip() if not line or line.startswith('#'): continue try: if line.startswith('@'): marker, pattern, data = line[1:].split(None, 2) else: marker = None pattern, data = line.split(None, 1) except ValueError: raise ValueError('Invalid known hosts entry: %s' % line) from None if marker not in (None, 'cert-authority', 'revoked'): raise ValueError('Invalid known hosts marker: %s' % marker) 
from None key = None cert = None subject = None try: key = import_public_key(data) except KeyImportError: try: cert = import_certificate(data) except KeyImportError: if not _x509_available: # pragma: no cover continue try: subject = import_certificate_subject(data) except KeyImportError: # Ignore keys in the file that we're unable to parse continue subject = X509NamePattern(subject) if any(c in pattern for c in '*?|/!'): self._add_pattern(marker, pattern, key, cert, subject) else: self._add_exact(marker, pattern, key, cert, subject) def _add_exact(self, marker, pattern, key, cert, subject): """Add an exact match entry""" for entry in pattern.split(','): if entry not in self._exact_entries: self._exact_entries[entry] = [] self._exact_entries[entry].append((marker, key, cert, subject)) def _add_pattern(self, marker, pattern, key, cert, subject): """Add a pattern match entry""" if pattern.startswith('|'): entry = _HashedHost(pattern) else: entry = _PlainHost(pattern) self._pattern_entries.append((entry, (marker, key, cert, subject))) def _match(self, host, addr, port=None): """Find host keys matching specified host, address, and port""" ip = ip_address(addr) if addr else None if port: host = '[{}]:{}'.format(host, port) if host else None addr = '[{}]:{}'.format(addr, port) if addr else None matches = [] matches += self._exact_entries.get(host, []) matches += self._exact_entries.get(addr, []) matches += (match for (entry, match) in self._pattern_entries if entry.matches(host, addr, ip)) host_keys = [] ca_keys = [] revoked_keys = [] x509_certs = [] revoked_certs = [] x509_subjects = [] revoked_subjects = [] for marker, key, cert, subject in matches: if key: if marker == 'revoked': revoked_keys.append(key) elif marker == 'cert-authority': ca_keys.append(key) else: host_keys.append(key) elif cert: if marker == 'revoked': revoked_certs.append(cert) else: x509_certs.append(cert) else: if marker == 'revoked': revoked_subjects.append(subject) else: x509_subjects.append(subject) return (host_keys, ca_keys, revoked_keys, x509_certs, revoked_certs, x509_subjects, revoked_subjects) def match(self, host, addr, port): """Match a host, IP address, and port against known_hosts patterns If the port is not the default port and no match is found for it, the lookup is attempted again without a port number. :param str host: The hostname of the target host :param str addr: The IP address of the target host :param int port: The port number on the target host, or ``None`` for the default :returns: A tuple of matching host keys, CA keys, and revoked keys """ host_keys, ca_keys, revoked_keys, x509_certs, revoked_certs, \ x509_subjects, revoked_subjects = self._match(host, addr, port) if port and not (host_keys or ca_keys or x509_certs or x509_subjects): host_keys, ca_keys, revoked_keys, x509_certs, revoked_certs, \ x509_subjects, revoked_subjects = self._match(host, addr) return (host_keys, ca_keys, revoked_keys, x509_certs, revoked_certs, x509_subjects, revoked_subjects) def import_known_hosts(data): """Import SSH known hosts This function imports known host patterns and keys in OpenSSH known hosts format. :param str data: The known hosts data to import :returns: An :class:`SSHKnownHosts` object """ return SSHKnownHosts(data) def read_known_hosts(filename): """Read SSH known hosts from a file This function reads known host patterns and keys in OpenSSH known hosts format from a file. 
:param str filename: The file to read the known hosts from :returns: An :class:`SSHKnownHosts` object """ with open(filename, 'r') as f: return import_known_hosts(f.read()) def match_known_hosts(known_hosts, host, addr, port): """Match a host, IP address, and port against a known_hosts list This function looks up a host, IP address, and port in a list of host patterns in OpenSSH ``known_hosts`` format and returns the host keys, CA keys, and revoked keys which match. The ``known_hosts`` argument can be any of the following: * a string containing the filename to load host patterns from * a byte string containing host pattern data to load * an already loaded :class:`SSHKnownHosts` object containing host patterns to match against * an alternate matching function which accepts a host, address, and port and returns lists of trusted host keys, trusted CA keys, and revoked keys to load * lists of trusted host keys, trusted CA keys, and revoked keys to load without doing any matching If the port is not the default port and no match is found for it, the lookup is attempted again without a port number. :param known_hosts: The host patterns to match against :param str host: The hostname of the target host :param str addr: The IP address of the target host :param int port: The port number on the target host, or ``None`` for the default :returns: A tuple of matching host keys, CA keys, and revoked keys """ if isinstance(known_hosts, str): known_hosts = read_known_hosts(known_hosts) elif isinstance(known_hosts, bytes): known_hosts = import_known_hosts(known_hosts.decode()) if isinstance(known_hosts, SSHKnownHosts): known_hosts = known_hosts.match(host, addr, port) else: if callable(known_hosts): known_hosts = known_hosts(host, addr, port) known_hosts = (tuple(map(load_public_keys, known_hosts[:3])) + tuple(map(load_certificates, known_hosts[3:5])) + tuple(map(_load_subject_names, known_hosts[5:]))) if len(known_hosts) == 3: # Provide backward compatibility for pre-X.509 releases known_hosts = tuple(known_hosts) + ((), (), (), ()) return known_hosts asyncssh-1.11.1/asyncssh/listener.py000066400000000000000000000174021320320510200174500ustar00rootroot00000000000000# Copyright (c) 2013-2016 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH listeners""" import asyncio import errno import socket from .forward import SSHLocalPortForwarder, SSHLocalPathForwarder class SSHListener(asyncio.AbstractServer): """SSH listener for inbound connections""" def get_port(self): """Return the port number being listened on This method returns the port number that the remote listener was bound to. When the requested remote listening port is ``0`` to indicate a dynamic port, this method can be called to determine what listening port was selected. This function only applies to TCP listeners. :returns: The port number being listened on """ # pylint: disable=no-self-use return 0 def close(self): """Stop listening for new connections This method can be called to stop listening for connections. Existing connections will remain open. """ raise NotImplementedError @asyncio.coroutine def wait_closed(self): """Wait for the listener to close This method is a coroutine which waits for the associated listeners to be closed. 
""" raise NotImplementedError class SSHClientListener(SSHListener): """Client listener used to accept inbound forwarded connections""" def __init__(self, conn, loop, session_factory, encoding, window, max_pktsize): self._conn = conn self._session_factory = session_factory self._encoding = encoding self._window = window self._max_pktsize = max_pktsize self._close_event = asyncio.Event(loop=loop) @asyncio.coroutine def _close(self): """Close this listener""" self._close_event.set() self._conn = None def close(self): """Close this listener asynchronously""" if self._conn: self._conn.create_task(self._close()) @asyncio.coroutine def wait_closed(self): """Wait for this listener to finish closing""" yield from self._close_event.wait() class SSHTCPClientListener(SSHClientListener): """Client listener used to accept inbound forwarded TCP connections""" def __init__(self, conn, loop, session_factory, listen_host, listen_port, encoding, window, max_pktsize): super().__init__(conn, loop, session_factory, encoding, window, max_pktsize) self._listen_host = listen_host self._listen_port = listen_port @asyncio.coroutine def _close(self): """Close this listener""" if self._conn: # pragma: no branch yield from self._conn.close_client_tcp_listener(self._listen_host, self._listen_port) yield from super()._close() def process_connection(self, orig_host, orig_port): """Process a forwarded TCP connection""" chan = self._conn.create_tcp_channel(self._encoding, self._window, self._max_pktsize) chan.set_inbound_peer_names(self._listen_host, self._listen_port, orig_host, orig_port) return chan, self._session_factory(orig_host, orig_port) def get_port(self): """Return the port number being listened on""" return self._listen_port class SSHUNIXClientListener(SSHClientListener): """Client listener used to accept inbound forwarded UNIX connections""" def __init__(self, conn, loop, session_factory, listen_path, encoding, window, max_pktsize): super().__init__(conn, loop, session_factory, encoding, window, max_pktsize) self._listen_path = listen_path @asyncio.coroutine def _close(self): """Close this listener""" if self._conn: # pragma: no branch yield from self._conn.close_client_unix_listener(self._listen_path) yield from super()._close() def process_connection(self): """Process a forwarded UNIX connection""" chan = self._conn.create_unix_channel(self._encoding, self._window, self._max_pktsize) chan.set_inbound_peer_names(self._listen_path) return chan, self._session_factory() class SSHForwardListener(SSHListener): """A listener used when forwarding traffic from local ports""" def __init__(self, servers, listen_port=0): self._servers = servers self._listen_port = listen_port def get_port(self): """Return the port number being listened on""" return self._listen_port def close(self): """Close this listener""" for server in self._servers: server.close() @asyncio.coroutine def wait_closed(self): """Wait for this listener to finish closing""" for server in self._servers: yield from server.wait_closed() self._servers = [] @asyncio.coroutine def create_tcp_forward_listener(conn, loop, coro, listen_host, listen_port): """Create a listener to forward traffic from local ports over SSH""" def protocol_factory(): """Start a port forwarder for each new local connection""" return SSHLocalPortForwarder(conn, coro) if listen_host == '': listen_host = None addrinfo = yield from loop.getaddrinfo(listen_host, listen_port, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM, flags=socket.AI_PASSIVE) if not addrinfo: # pragma: no cover raise 
OSError('getaddrinfo() returned empty list') servers = [] for family, socktype, proto, _, sa in addrinfo: try: sock = socket.socket(family, socktype, proto) except OSError: # pragma: no cover continue sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True) if family == socket.AF_INET6: try: sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, True) except AttributeError: # pragma: no cover pass if sa[1] == 0: sa = sa[:1] + (listen_port,) + sa[2:] try: sock.bind(sa) except (OSError, OverflowError) as exc: sock.close() for server in servers: server.close() if isinstance(exc, OverflowError): # pragma: no cover exc.errno = errno.EOVERFLOW exc.strerror = str(exc) raise OSError(exc.errno, 'error while attempting to bind on ' 'address %r: %s' % (sa, exc.strerror)) from None if listen_port == 0: listen_port = sock.getsockname()[1] server = yield from loop.create_server(protocol_factory, sock=sock) servers.append(server) return SSHForwardListener(servers, listen_port) @asyncio.coroutine def create_unix_forward_listener(conn, loop, coro, listen_path): """Create a listener to forward a local UNIX domain socket over SSH""" def protocol_factory(): """Start a path forwarder for each new local connection""" return SSHLocalPathForwarder(conn, coro) server = yield from loop.create_unix_server(protocol_factory, listen_path) return SSHForwardListener([server]) asyncssh-1.11.1/asyncssh/logging.py000066400000000000000000000007601320320510200172500ustar00rootroot00000000000000# Copyright (c) 2013-2015 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Sam Crooks - initial implementation # Ron Frederick - minor cleanup """Logging functions""" import logging logger = logging.getLogger(__package__) asyncssh-1.11.1/asyncssh/mac.py000066400000000000000000000104401320320510200163560ustar00rootroot00000000000000# Copyright (c) 2013-2016 by Ron Frederick . # All rights reserved. 
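# Example (illustrative sketch, not part of this module): the forward
# listener helpers above are normally reached through the public
# forwarding API, which returns an SSHListener like the ones defined in
# listener.py. A minimal sketch of that usage follows; the host name,
# port numbers, and event loop handling are placeholder assumptions.

import asyncio
import asyncssh


async def forwarding_example():
    async with asyncssh.connect('example.com') as conn:
        # Listen on a dynamically assigned local port and forward each
        # connection to port 80 on the remote side of the SSH session
        listener = await conn.forward_local_port('', 0, 'localhost', 80)
        print('Forwarding connections from local port', listener.get_port())
        await listener.wait_closed()


asyncio.get_event_loop().run_until_complete(forwarding_example())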
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v1.0 which accompanies this
# distribution and is available at:
#
#     http://www.eclipse.org/legal/epl-v10.html
#
# Contributors:
#     Ron Frederick - initial implementation, API, and documentation

"""SSH message authentication handlers"""

import hmac
from hashlib import md5, sha1, sha256, sha512

from .packet import UInt32, UInt64

try:
    from .crypto import umac64, umac128
    _umac_available = True
except ImportError: # pragma: no cover
    _umac_available = False

_OPENSSH = b'@openssh.com'
_ETM = b'-etm' + _OPENSSH

_mac_algs = []
_mac_params = {}
_mac_handlers = {}


class _HMAC:
    """Parent class for HMAC-based SSH message authentication handlers"""

    def __init__(self, key, hash_size, hash_alg):
        self._key = key
        self._hash_size = hash_size
        self._hash_alg = hash_alg

    def sign(self, seq, packet):
        """Compute a signature for a message"""

        data = UInt32(seq) + packet
        sig = hmac.new(self._key, data, self._hash_alg).digest()
        return sig[:self._hash_size]

    def verify(self, seq, packet, sig):
        """Verify the signature of a message"""

        return self.sign(seq, packet) == sig


class _UMAC:
    """Parent class for UMAC-based SSH message authentication handlers"""

    def __init__(self, key, hash_size, umac_alg):
        # pylint: disable=unused-argument
        self._key = key
        self._umac_alg = umac_alg

    def sign(self, seq, packet):
        """Compute a signature for a message"""

        return self._umac_alg(self._key, packet, UInt64(seq)).digest()

    def verify(self, seq, packet, sig):
        """Verify the signature of a message"""

        return self.sign(seq, packet) == sig


def register_mac_alg(alg, key_size, hash_size, etm, mac_alg, *args):
    """Register a MAC algorithm"""

    _mac_algs.append(alg)
    _mac_params[alg] = (key_size, hash_size, etm)
    _mac_handlers[alg] = (mac_alg, hash_size, args)


def get_mac_algs():
    """Return a list of available MAC algorithms"""

    return _mac_algs


def get_mac_params(alg):
    """Get parameters of a MAC algorithm

       This function returns the key and hash sizes of a MAC algorithm
       and whether or not to compute the MAC before or after encryption.
    """

    return _mac_params[alg]


def get_mac(alg, key):
    """Return a MAC handler

       This function returns a MAC object initialized with the specified
       key that can be used for data signing and verification.
""" mac_alg, hash_size, args = _mac_handlers[alg] return mac_alg(key, hash_size, *args) # pylint: disable=bad-whitespace if _umac_available: # pragma: no branch register_mac_alg(b'umac-64' + _ETM, 16, 8, True, _UMAC, umac64) register_mac_alg(b'umac-128' + _ETM, 16, 16, True, _UMAC, umac128) register_mac_alg(b'hmac-sha2-256' + _ETM, 32, 32, True, _HMAC, sha256) register_mac_alg(b'hmac-sha2-512' + _ETM, 64, 64, True, _HMAC, sha512) register_mac_alg(b'hmac-sha1' + _ETM, 20, 20, True, _HMAC, sha1) register_mac_alg(b'hmac-md5' + _ETM, 16, 16, True, _HMAC, md5) register_mac_alg(b'hmac-sha2-256-96' + _ETM, 32, 12, True, _HMAC, sha256) register_mac_alg(b'hmac-sha2-512-96' + _ETM, 64, 12, True, _HMAC, sha512) register_mac_alg(b'hmac-sha1-96' + _ETM, 20, 12, True, _HMAC, sha1) register_mac_alg(b'hmac-md5-96' + _ETM, 16, 12, True, _HMAC, md5) if _umac_available: # pragma: no branch register_mac_alg(b'umac-64' + _OPENSSH, 16, 8, False, _UMAC, umac64) register_mac_alg(b'umac-128' + _OPENSSH, 16, 16, False, _UMAC, umac128) register_mac_alg(b'hmac-sha2-256', 32, 32, False, _HMAC, sha256) register_mac_alg(b'hmac-sha2-512', 64, 64, False, _HMAC, sha512) register_mac_alg(b'hmac-sha1', 20, 20, False, _HMAC, sha1) register_mac_alg(b'hmac-md5', 16, 16, False, _HMAC, md5) register_mac_alg(b'hmac-sha2-256-96', 32, 12, False, _HMAC, sha256) register_mac_alg(b'hmac-sha2-512-96', 64, 12, False, _HMAC, sha512) register_mac_alg(b'hmac-sha1-96', 20, 12, False, _HMAC, sha1) register_mac_alg(b'hmac-md5-96', 16, 12, False, _HMAC, md5) asyncssh-1.11.1/asyncssh/misc.py000066400000000000000000000224751320320510200165640ustar00rootroot00000000000000# Copyright (c) 2013-2017 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Miscellaneous utility classes and functions""" import asyncio import functools import ipaddress import os import platform import socket from collections import OrderedDict from random import SystemRandom import asyncssh from .constants import DEFAULT_LANG # Provide globals to test if we're on various Python versions python344 = platform.python_version_tuple() >= ('3', '4', '4') python35 = platform.python_version_tuple() >= ('3', '5', '0') python352 = platform.python_version_tuple() >= ('3', '5', '2') # Define a version of randrange which is based on SystemRandom(), so that # we get back numbers suitable for cryptographic use. 
_random = SystemRandom() randrange = _random.randrange # Avoid deprecation warning for asyncio.async() if python344: create_task = asyncio.ensure_future else: # pragma: no cover create_task = asyncio.async # pylint: disable=no-member def all_ints(seq): """Return if a sequence contains all integers""" return all(isinstance(i, int) for i in seq) # Default file names in .ssh directory to read private keys from _DEFAULT_KEY_FILES = ('id_ed25519', 'id_ecdsa', 'id_rsa', 'id_dsa') def load_default_keypairs(passphrase=None): """Return a list of default keys from the user's home directory""" result = [] for file in _DEFAULT_KEY_FILES: try: file = os.path.join(os.path.expanduser('~'), '.ssh', file) result.extend(asyncssh.load_keypairs(file, passphrase)) except asyncssh.KeyImportError as exc: # Ignore encrypted default keys if a passphrase isn't provided if not str(exc).startswith('Passphrase'): raise except OSError: pass return result # Punctuation to map when creating handler names _HANDLER_PUNCTUATION = (('@', '_at_'), ('.', '_dot_'), ('-', '_')) def map_handler_name(name): """Map punctuation so a string can be used as a handler name""" for old, new in _HANDLER_PUNCTUATION: name = name.replace(old, new) return name def _normalize_scoped_ip(addr): """Normalize scoped IP address The ipaddress module doesn't handle scoped addresses properly, so we strip off the CIDR suffix here and normalize scoped IP addresses using socket.inet_pton before we pass them into ipaddress. """ for family in (socket.AF_INET, socket.AF_INET6): try: return socket.inet_ntop(family, socket.inet_pton(family, addr)) except (ValueError, socket.error): pass return addr def ip_address(addr): """Wrapper for ipaddress.ip_address which supports scoped addresses""" return ipaddress.ip_address(_normalize_scoped_ip(addr)) def ip_network(addr): """Wrapper for ipaddress.ip_network which supports scoped addresses""" idx = addr.find('/') if idx >= 0: addr, mask = addr[:idx], addr[idx:] else: mask = '' return ipaddress.ip_network(_normalize_scoped_ip(addr) + mask) if python352: async_iterator = lambda iter: iter else: async_iterator = asyncio.coroutine def async_context_manager(coro): """Decorator for methods returning asynchronous context managers This function can be used as a decorator for coroutines which return objects intended to be used as Python 3.5 asynchronous context managers. The object returned should implement __aenter__ and __aexit__ methods to run when the async context is entered and exited. This wrapper also allows non-async context managers to be defined on the returned object, as well as the use of "await" or "yield from" on the function being decorated for backward compatibility with the API defined by older versions of AsyncSSH. 
""" class AsyncContextManager: """Async context manager wrapper for Python 3.5 and later""" def __init__(self, coro): self._coro = coro self._result = None def __iter__(self): return (yield from self._coro) def __await__(self): return (yield from self._coro) @asyncio.coroutine def __aenter__(self): self._result = yield from self._coro return (yield from self._result.__aenter__()) @asyncio.coroutine def __aexit__(self, *exc_info): yield from self._result.__aexit__(*exc_info) self._result = None @functools.wraps(coro) def coro_wrapper(*args, **kwargs): """Return an async context manager wrapper for this coroutine""" return AsyncContextManager(asyncio.coroutine(coro)(*args, **kwargs)) if python35: return coro_wrapper else: return coro class Record: """General-purpose record type with fixed set of fields""" __slots__ = OrderedDict() def __init__(self, *args, **kwargs): for k, v in self.__slots__.items(): setattr(self, k, v) for k, v in zip(self.__slots__, args): setattr(self, k, v) for k, v in kwargs.items(): setattr(self, k, v) def __repr__(self): return '%s(%s)' % (type(self).__name__, ', '.join('%s=%r' % (k, getattr(self, k)) for k in self.__slots__)) class Error(Exception): """General SSH error""" def __init__(self, errtype, code, reason, lang): super().__init__('%s Error: %s' % (errtype, reason)) self.code = code self.reason = reason self.lang = lang class DisconnectError(Error): """SSH disconnect error This exception is raised when a serious error occurs which causes the SSH connection to be disconnected. Exception codes should be taken from :ref:`disconnect reason codes `. :param int code: Disconnect reason, taken from :ref:`disconnect reason codes ` :param str reason: A human-readable reason for the disconnect :param str lang: The language the reason is in """ def __init__(self, code, reason, lang=DEFAULT_LANG): super().__init__('Disconnect', code, reason, lang) class ChannelOpenError(Error): """SSH channel open error This exception is raised by connection handlers to report channel open failures. :param int code: Channel open failure reason, taken from :ref:`channel open failure reason codes ` :param str reason: A human-readable reason for the channel open failure :param str lang: The language the reason is in """ def __init__(self, code, reason, lang=DEFAULT_LANG): super().__init__('Channel Open', code, reason, lang) class PasswordChangeRequired(Exception): """SSH password change required This exception is raised during password validation on the server to indicate that a password change is required. It shouuld be raised when the password provided is valid but expired, to trigger the client to provide a new password. :param str prompt: The prompt requesting that the user enter a new password :param str lang: The language that the prompt is in """ def __init__(self, prompt, lang=DEFAULT_LANG): super().__init__('Password change required: %s' % prompt) self.prompt = prompt self.lang = lang class BreakReceived(Exception): """SSH break request received This exception is raised on an SSH server stdin stream when the client sends a break on the channel. :param int msec: The duration of the break in milliseconds """ def __init__(self, msec): super().__init__('Break for %s msec' % msec) self.msec = msec class SignalReceived(Exception): """SSH signal request received This exception is raised on an SSH server stdin stream when the client sends a signal on the channel. 
:param str signal: The name of the signal sent by the client """ def __init__(self, signal): super().__init__('Signal: %s' % signal) self.signal = signal class TerminalSizeChanged(Exception): """SSH terminal size change notification received This exception is raised on an SSH server stdin stream when the client sends a terminal size change on the channel. :param int width: The new terminal width :param int height: The new terminal height :param int pixwidth: The new terminal width in pixels :param int pixheight: The new terminal height in pixels """ def __init__(self, width, height, pixwidth, pixheight): super().__init__('Terminal size change: (%s, %s, %s, %s)' % (width, height, pixwidth, pixheight)) self.width = width self.height = height self.pixwidth = pixwidth self.pixheight = pixheight asyncssh-1.11.1/asyncssh/packet.py000066400000000000000000000103711320320510200170700ustar00rootroot00000000000000# Copyright (c) 2013-2015 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH packet encoding and decoding functions""" class PacketDecodeError(ValueError): """Packet decoding error""" def Byte(value): """Encode a single byte""" return bytes((value,)) def Boolean(value): """Encode a boolean value""" return Byte(bool(value)) def UInt32(value): """Encode a 32-bit integer value""" return value.to_bytes(4, 'big') def UInt64(value): """Encode a 64-bit integer value""" return value.to_bytes(8, 'big') def String(value): """Encode a byte string or UTF-8 string value""" if isinstance(value, str): value = value.encode('utf-8', errors='strict') return len(value).to_bytes(4, 'big') + value def MPInt(value): """Encode a multiple precision integer value""" l = value.bit_length() l += (l % 8 == 0 and value != 0 and value != -1 << (l - 1)) l = (l + 7) // 8 return l.to_bytes(4, 'big') + value.to_bytes(l, 'big', signed=True) def NameList(value): """Encode a comma-separated list of byte strings""" return String(b','.join(value)) class SSHPacket: """Decoder class for SSH packets""" def __init__(self, packet): self._packet = packet self._idx = 0 self._len = len(packet) def __bool__(self): return self._idx != self._len def check_end(self): """Confirm that all of the data in the packet has been consumed""" if self: raise PacketDecodeError('Unexpected data at end of packet') def get_consumed_payload(self): """Return the portion of the packet consumed so far""" return self._packet[:self._idx] def get_remaining_payload(self): """Return the portion of the packet not yet consumed""" return self._packet[self._idx:] def get_bytes(self, size): """Extract the requested number of bytes from the packet""" if self._idx + size > self._len: raise PacketDecodeError('Incomplete packet') value = self._packet[self._idx:self._idx+size] self._idx += size return value def get_byte(self): """Extract a single byte from the packet""" return self.get_bytes(1)[0] def get_boolean(self): """Extract a boolean from the packet""" return bool(self.get_byte()) def get_uint32(self): """Extract a 32-bit integer from the packet""" return int.from_bytes(self.get_bytes(4), 'big') def get_uint64(self): """Extract a 64-bit integer from the packet""" return int.from_bytes(self.get_bytes(8), 'big') def get_string(self): """Extract a UTF-8 string from the 
packet""" return self.get_bytes(self.get_uint32()) def get_mpint(self): """Extract a multiple precision integer from the packet""" return int.from_bytes(self.get_string(), 'big', signed=True) def get_namelist(self): """Extract a comma-separated list of byte strings from the packet""" namelist = self.get_string() return namelist.split(b',') if namelist else [] class SSHPacketHandler: """Parent class for SSH packet handlers Classes wishing to decode SSH packets can inherit from this class, defining the class variable packet_handlers as a dictionary which maps SSH packet types to handler methods in the class and then calling process_packet() to run the corresponding packet handler. The process_packet() function will return True if a handler was found and False otherwise. """ packet_handlers = {} def process_packet(self, pkttype, packet): """Call the packet handler defined for the specified packet. Return True if a handler was found, or False otherwise.""" if pkttype in self.packet_handlers: self.packet_handlers[pkttype](self, pkttype, packet) return True else: return False asyncssh-1.11.1/asyncssh/pattern.py000066400000000000000000000077031320320510200173030ustar00rootroot00000000000000# Copyright (c) 2015 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Pattern matching for principal and host names""" from fnmatch import fnmatch from .misc import ip_network class WildcardPattern: """A pattern matcher for '*' and '?' wildcards""" def __init__(self, pattern): # We need to escape square brackets in host patterns if we # want to use Python's fnmatch. self._pattern = ''.join('[[]' if ch == '[' else '[]]' if ch == ']' else ch for ch in pattern) def matches(self, value): """Return whether a wild card pattern matches a value""" return fnmatch(value, self._pattern) class WildcardHostPattern(WildcardPattern): """Match a host name or address against a wildcard pattern""" def matches(self, host, addr, ip): """Return whether a host or address matches a wild card host pattern""" # Arguments vary by class, but inheritance is still needed here # IP matching is only done for CIDRHostPattern # pylint: disable=arguments-differ,unused-argument return (host and super().matches(host)) or \ (addr and super().matches(addr)) class CIDRHostPattern: """Match IPv4/v6 address against CIDR-style subnet pattern""" def __init__(self, pattern): self._network = ip_network(pattern) def matches(self, host, addr, ip): """Return whether an IP address matches a CIDR address pattern""" # Host & addr matching is only done for WildcardHostPattern # pylint: disable=unused-argument return ip and ip in self._network class _PatternList: """Match against a list of comma-separated positive and negative patterns This class is a base class for building a pattern matcher that takes a set of comma-separated positive and negative patterns, returning ``True`` if one or more positive patterns match and no negative ones do. The pattern matching is done by objects returned by the build_pattern method. The arguments passed in when a match is performed will vary depending on what class build_pattern returns. 
""" def __init__(self, patterns): self._pos_patterns = [] self._neg_patterns = [] for pattern in patterns.split(','): if pattern.startswith('!'): negate = True pattern = pattern[1:] else: negate = False matcher = self.build_pattern(pattern) if negate: self._neg_patterns.append(matcher) else: self._pos_patterns.append(matcher) def build_pattern(self, pattern): """Abstract method to build a pattern object""" raise NotImplementedError def matches(self, *args): """Match a set of values against positive & negative pattern lists""" pos_match = any(p.matches(*args) for p in self._pos_patterns) neg_match = any(p.matches(*args) for p in self._neg_patterns) return pos_match and not neg_match class WildcardPatternList(_PatternList): """Match names against wildcard patterns""" def build_pattern(self, pattern): """Build a wild card pattern""" return WildcardPattern(pattern) class HostPatternList(_PatternList): """Match host names & addresses against wildcard and CIDR patterns""" def build_pattern(self, pattern): """Build a CIDR address or wild card host pattern""" try: return CIDRHostPattern(pattern) except ValueError: return WildcardHostPattern(pattern) asyncssh-1.11.1/asyncssh/pbe.py000066400000000000000000000434031320320510200163710ustar00rootroot00000000000000# Copyright (c) 2013-2015 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Asymmetric key password based encryption functions""" import os from hashlib import md5, sha1 from .asn1 import ASN1DecodeError, ObjectIdentifier, der_encode, der_decode from .crypto import lookup_cipher, pbkdf2_hmac # pylint: disable=bad-whitespace _ES1_MD5_DES = ObjectIdentifier('1.2.840.113549.1.5.3') _ES1_SHA1_DES = ObjectIdentifier('1.2.840.113549.1.5.10') _ES2 = ObjectIdentifier('1.2.840.113549.1.5.13') _P12_RC4_128 = ObjectIdentifier('1.2.840.113549.1.12.1.1') _P12_RC4_40 = ObjectIdentifier('1.2.840.113549.1.12.1.2') _P12_DES3 = ObjectIdentifier('1.2.840.113549.1.12.1.3') _P12_DES2 = ObjectIdentifier('1.2.840.113549.1.12.1.4') _ES2_CAST128 = ObjectIdentifier('1.2.840.113533.7.66.10') _ES2_DES3 = ObjectIdentifier('1.2.840.113549.3.7') _ES2_BF = ObjectIdentifier('1.3.6.1.4.1.3029.1.2') _ES2_DES = ObjectIdentifier('1.3.14.3.2.7') _ES2_AES128 = ObjectIdentifier('2.16.840.1.101.3.4.1.2') _ES2_AES192 = ObjectIdentifier('2.16.840.1.101.3.4.1.22') _ES2_AES256 = ObjectIdentifier('2.16.840.1.101.3.4.1.42') _ES2_PBKDF2 = ObjectIdentifier('1.2.840.113549.1.5.12') _ES2_SHA1 = ObjectIdentifier('1.2.840.113549.2.7') _ES2_SHA224 = ObjectIdentifier('1.2.840.113549.2.8') _ES2_SHA256 = ObjectIdentifier('1.2.840.113549.2.9') _ES2_SHA384 = ObjectIdentifier('1.2.840.113549.2.10') _ES2_SHA512 = ObjectIdentifier('1.2.840.113549.2.11') # pylint: enable=bad-whitespace _pkcs1_ciphers = {} _pkcs8_ciphers = {} _pbes2_ciphers = {} _pbes2_kdfs = {} _pbes2_prfs = {} _pkcs1_cipher_names = {} _pkcs8_cipher_suites = {} _pbes2_cipher_names = {} _pbes2_kdf_names = {} _pbes2_prf_names = {} class KeyEncryptionError(ValueError): """Key encryption error This exception is raised by key decryption functions when the data provided is not a valid encrypted private key. 
""" class _RFC1423Pad: """RFC 1423 padding functions This class implements RFC 1423 padding for encryption and decryption of data by block ciphers. On encryption, the data is padded by between 1 and the cipher's block size number of bytes, with the padding value being equal to the length of the padding. """ def __init__(self, cipher): self._cipher = cipher self._block_size = cipher.block_size def encrypt(self, data): """Pad data before encrypting it""" pad = self._block_size - (len(data) % self._block_size) data += pad * bytes((pad,)) return self._cipher.encrypt(data) def decrypt(self, data): """Remove padding from data after decrypting it""" data = self._cipher.decrypt(data) if data: pad = data[-1] if (1 <= pad <= self._block_size and data[-pad:] == pad * bytes((pad,))): return data[:-pad] raise KeyEncryptionError('Unable to decrypt key') def _pbkdf1(hash_alg, passphrase, salt, count, key_size): """PKCS#5 v1.5 key derivation function for password-based encryption This function implements the PKCS#5 v1.5 algorithm for deriving an encryption key from a passphrase and salt. The standard PBKDF1 function cannot generate more key bytes than the hash digest size, but 3DES uses a modified form of it which calls PBKDF1 recursively on the result to generate more key data. Support for this is implemented here. """ if isinstance(passphrase, str): passphrase = passphrase.encode('utf-8') key = passphrase + salt for _ in range(count): key = hash_alg(key).digest() if len(key) <= key_size: return key + _pbkdf1(hash_alg, key + passphrase, salt, count, key_size - len(key)) else: return key[:key_size] def _pbkdf_p12(hash_alg, passphrase, salt, count, key_size, idx): """PKCS#12 key derivation function for password-based encryption This function implements the PKCS#12 algorithm for deriving an encryption key from a passphrase and salt. """ # Short variable names are used here, matching names in the spec # pylint: disable=invalid-name def _make_block(data, v): """Make a block a multiple of v bytes long by repeating data""" l = len(data) size = ((l + v - 1) // v) * v return (((size + l - 1) // l) * data)[:size] v = hash_alg().block_size D = v * bytes((idx,)) if isinstance(passphrase, str): passphrase = passphrase.encode('utf-16be') I = bytearray(_make_block(salt, v) + _make_block(passphrase + b'\0\0', v)) key = b'' while len(key) < key_size: A = D + I for i in range(count): A = hash_alg(A).digest() B = int.from_bytes(_make_block(A, v), 'big') for i in range(0, len(I), v): x = (int.from_bytes(I[i:i+v], 'big') + B + 1) % (1 << v*8) I[i:i+v] = x.to_bytes(v, 'big') key += A return key[:key_size] def _pbes1(params, passphrase, hash_alg, cipher, key_size): """PKCS#5 v1.5 cipher selection function for password-based encryption This function implements the PKCS#5 v1.5 algorithm for password-based encryption. It returns a cipher object which can be used to encrypt or decrypt data based on the specified encryption parameters, passphrase, and salt. """ if (not isinstance(params, tuple) or len(params) != 2 or not isinstance(params[0], bytes) or not isinstance(params[1], int)): raise KeyEncryptionError('Invalid PBES1 encryption parameters') salt, count = params key = _pbkdf1(hash_alg, passphrase, salt, count, key_size + cipher.block_size) key, iv = key[:key_size], key[key_size:] return _RFC1423Pad(cipher.new(key, iv)) def _pbe_p12(params, passphrase, hash_alg, cipher, key_size): """PKCS#12 cipher selection function for password-based encryption This function implements the PKCS#12 algorithm for password-based encryption. 
It returns a cipher object which can be used to encrypt or decrypt data based on the specified encryption parameters, passphrase, and salt. """ if (not isinstance(params, tuple) or len(params) != 2 or not isinstance(params[0], bytes) or len(params[0]) == 0 or not isinstance(params[1], int) or params[1] == 0): raise KeyEncryptionError('Invalid PBES1 PKCS#12 encryption parameters') salt, count = params key = _pbkdf_p12(hash_alg, passphrase, salt, count, key_size, 1) if cipher.cipher_name == 'arc4': cipher = cipher.new(key) else: iv = _pbkdf_p12(hash_alg, passphrase, salt, count, cipher.block_size, 2) cipher = _RFC1423Pad(cipher.new(key, iv)) return cipher def _pbes2_iv(params, key, cipher): """PKCS#5 v2.0 handler for PBES2 ciphers with an IV as a parameter This function returns the appropriate cipher object to use for PBES2 encryption for ciphers that have only an IV as an encryption parameter. """ if len(params) != 1 or not isinstance(params[0], bytes): raise KeyEncryptionError('Invalid PBES2 encryption parameters') if len(params[0]) != cipher.block_size: raise KeyEncryptionError('Invalid length IV for PBES2 encryption') return cipher.new(key, params[0]) def _pbes2_pbkdf2(params, passphrase, default_key_size): """PKCS#5 v2.0 handler for PBKDF2 key derivation This function parses the PBKDF2 arguments from a PKCS#8 encrypted key and returns the encryption key to use for encryption. """ if (len(params) != 1 or not isinstance(params[0], tuple) or len(params[0]) < 2): raise KeyEncryptionError('Invalid PBES2 key derivation parameters') params = list(params[0]) if not isinstance(params[0], bytes) or not isinstance(params[1], int): raise KeyEncryptionError('Invalid PBES2 key derivation parameters') salt = params.pop(0) count = params.pop(0) if params and isinstance(params[0], int): key_size = params.pop(0) # pragma: no cover, used only by RC2 else: key_size = default_key_size if params: if (isinstance(params[0], tuple) and len(params[0]) == 2 and isinstance(params[0][0], ObjectIdentifier)): prf_alg = params[0][0] if prf_alg in _pbes2_prfs: hash_name = _pbes2_prfs[prf_alg] else: raise KeyEncryptionError('Unknown PBES2 pseudo-random ' 'function') else: raise KeyEncryptionError('Invalid PBES2 pseudo-random function ' 'parameters') else: hash_name = 'sha1' if isinstance(passphrase, str): passphrase = passphrase.encode('utf-8') return pbkdf2_hmac(hash_name, passphrase, salt, count, key_size) def _pbes2(params, passphrase): """PKCS#5 v2.0 cipher selection function for password-based encryption This function implements the PKCS#5 v2.0 algorithm for password-based encryption. It returns a cipher object which can be used to encrypt or decrypt data based on the specified encryption parameters and passphrase. 
""" if (not isinstance(params, tuple) or len(params) != 2 or not isinstance(params[0], tuple) or len(params[0]) < 1 or not isinstance(params[1], tuple) or len(params[1]) < 1): raise KeyEncryptionError('Invalid PBES2 encryption parameters') kdf_params = list(params[0]) kdf_alg = kdf_params.pop(0) if kdf_alg not in _pbes2_kdfs: raise KeyEncryptionError('Unknown PBES2 key derivation function') enc_params = list(params[1]) enc_alg = enc_params.pop(0) if enc_alg not in _pbes2_ciphers: raise KeyEncryptionError('Unknown PBES2 encryption algorithm') kdf_handler, kdf_args = _pbes2_kdfs[kdf_alg] enc_handler, cipher, default_key_size = _pbes2_ciphers[enc_alg] key = kdf_handler(kdf_params, passphrase, default_key_size, *kdf_args) return _RFC1423Pad(enc_handler(enc_params, key, cipher)) def register_pkcs1_cipher(cipher_name, alg, cipher, mode, key_size): """Register a cipher used for PKCS#1 private key encryption""" cipher = lookup_cipher(cipher, mode) if cipher: # pragma: no branch _pkcs1_ciphers[alg] = (cipher, key_size) _pkcs1_cipher_names[cipher_name] = alg def register_pkcs8_cipher(cipher_name, hash_name, alg, handler, hash_alg, cipher, mode, key_size): """Register a cipher used for PKCS#8 private key encryption""" cipher = lookup_cipher(cipher, mode) if cipher: # pragma: no branch _pkcs8_ciphers[alg] = (handler, hash_alg, cipher, key_size) _pkcs8_cipher_suites[cipher_name, hash_name] = alg def register_pbes2_cipher(cipher_name, alg, handler, cipher, mode, key_size): """Register a PBES2 encryption algorithm""" cipher = lookup_cipher(cipher, mode) if cipher: # pragma: no branch _pbes2_ciphers[alg] = (handler, cipher, key_size) _pbes2_cipher_names[cipher_name] = (alg, key_size) def register_pbes2_kdf(kdf_name, alg, handler, *args): """Register a PBES2 key derivation function""" _pbes2_kdfs[alg] = (handler, args) _pbes2_kdf_names[kdf_name] = alg def register_pbes2_prf(hash_name, alg): """Register a PBES2 pseudo-random function""" _pbes2_prfs[alg] = hash_name _pbes2_prf_names[hash_name] = alg def pkcs1_encrypt(data, cipher, passphrase): """Encrypt PKCS#1 key data This function encrypts PKCS#1 key data using the specified cipher and passphrase. Available ciphers include: aes128-cbc, aes192-cbc, aes256-cbc, des-cbc, des3-cbc """ if cipher in _pkcs1_cipher_names: alg = _pkcs1_cipher_names[cipher] cipher, key_size = _pkcs1_ciphers[alg] iv = os.urandom(cipher.block_size) key = _pbkdf1(md5, passphrase, iv[:8], 1, key_size) cipher = _RFC1423Pad(cipher.new(key, iv)) return alg, iv, cipher.encrypt(data) else: raise KeyEncryptionError('Unknown PKCS#1 encryption algorithm') def pkcs1_decrypt(data, alg, iv, passphrase): """Decrypt PKCS#1 key data This function decrypts PKCS#1 key data using the specified algorithm, initialization vector, and passphrase. The algorithm name and IV should be taken from the PEM DEK-Info header. """ if alg in _pkcs1_ciphers: cipher, key_size = _pkcs1_ciphers[alg] key = _pbkdf1(md5, passphrase, iv[:8], 1, key_size) cipher = _RFC1423Pad(cipher.new(key, iv)) return cipher.decrypt(data) else: raise KeyEncryptionError('Unknown PKCS#1 encryption algorithm') def pkcs8_encrypt(data, cipher_name, hash_name, version, passphrase): """Encrypt PKCS#8 key data This function encrypts PKCS#8 key data using the specified cipher, hash, encryption version, and passphrase. 
Available ciphers include: aes128-cbc, aes192-cbc, aes256-cbc, blowfish-cbc, cast128-cbc, des-cbc, des2-cbc, des3-cbc, rc4-40, and rc4-128 Available hashes include: md5, sha1, sha256, sha384, sha512 Available versions include 1 for PBES1 and 2 for PBES2. Only some combinations of cipher, hash, and version are supported. """ if version == 1 and (cipher_name, hash_name) in _pkcs8_cipher_suites: alg = _pkcs8_cipher_suites[cipher_name, hash_name] handler, hash_alg, cipher, key_size = _pkcs8_ciphers[alg] params = (os.urandom(8), 2048) cipher = handler(params, passphrase, hash_alg, cipher, key_size) return der_encode(((alg, params), cipher.encrypt(data))) elif version == 2 and cipher_name in _pbes2_cipher_names: enc_alg, key_size = _pbes2_cipher_names[cipher_name] _, cipher, _ = _pbes2_ciphers[enc_alg] kdf_params = [os.urandom(8), 2048] iv = os.urandom(cipher.block_size) enc_params = (enc_alg, iv) if hash_name != 'sha1': if hash_name in _pbes2_prf_names: kdf_params.append((_pbes2_prf_names[hash_name], None)) else: raise KeyEncryptionError('Unknown PBES2 hash function') alg = _ES2 params = ((_ES2_PBKDF2, tuple(kdf_params)), enc_params) cipher = _pbes2(params, passphrase) else: raise KeyEncryptionError('Unknown PKCS#8 encryption algorithm') return der_encode(((alg, params), cipher.encrypt(data))) def pkcs8_decrypt(key_data, passphrase): """Decrypt PKCS#8 key data This function decrypts key data in PKCS#8 EncryptedPrivateKeyInfo format using the specified passphrase. """ if not isinstance(key_data, tuple) or len(key_data) != 2: raise KeyEncryptionError('Invalid PKCS#8 encrypted key format') alg_params, data = key_data if (not isinstance(alg_params, tuple) or len(alg_params) != 2 or not isinstance(data, bytes)): raise KeyEncryptionError('Invalid PKCS#8 encrypted key format') alg, params = alg_params if alg == _ES2: cipher = _pbes2(params, passphrase) elif alg in _pkcs8_ciphers: handler, hash_alg, cipher, key_size = _pkcs8_ciphers[alg] cipher = handler(params, passphrase, hash_alg, cipher, key_size) else: raise KeyEncryptionError('Unknown PKCS#8 encryption algorithm') try: return der_decode(cipher.decrypt(data)) except ASN1DecodeError: raise KeyEncryptionError('Invalid PKCS#8 encrypted key data') # pylint: disable=bad-whitespace _pkcs1_cipher_list = ( ('aes128-cbc', b'AES-128-CBC', 'aes', 'cbc', 16), ('aes192-cbc', b'AES-192-CBC', 'aes', 'cbc', 24), ('aes256-cbc', b'AES-256-CBC', 'aes', 'cbc', 32), ('des-cbc', b'DES-CBC', 'des', 'cbc', 8), ('des3-cbc', b'DES-EDE3-CBC', 'des3', 'cbc', 24) ) _pkcs8_cipher_list = ( ('des-cbc', 'md5', _ES1_MD5_DES, _pbes1, md5, 'des', 'cbc', 8), ('des-cbc', 'sha1', _ES1_SHA1_DES, _pbes1, sha1, 'des', 'cbc', 8), ('des2-cbc', 'sha1', _P12_DES2, _pbe_p12, sha1, 'des3', 'cbc', 16), ('des3-cbc', 'sha1', _P12_DES3, _pbe_p12, sha1, 'des3', 'cbc', 24), ('rc4-40', 'sha1', _P12_RC4_40, _pbe_p12, sha1, 'arc4', None, 5), ('rc4-128', 'sha1', _P12_RC4_128, _pbe_p12, sha1, 'arc4', None, 16) ) _pbes2_cipher_list = ( ('aes128-cbc', _ES2_AES128, _pbes2_iv, 'aes', 'cbc', 16), ('aes192-cbc', _ES2_AES192, _pbes2_iv, 'aes', 'cbc', 24), ('aes256-cbc', _ES2_AES256, _pbes2_iv, 'aes', 'cbc', 32), ('blowfish-cbc', _ES2_BF, _pbes2_iv, 'blowfish', 'cbc', 16), ('cast128-cbc', _ES2_CAST128, _pbes2_iv, 'cast', 'cbc', 16), ('des-cbc', _ES2_DES, _pbes2_iv, 'des', 'cbc', 8), ('des3-cbc', _ES2_DES3, _pbes2_iv, 'des3', 'cbc', 24) ) _pbes2_kdf_list = ( ('pbkdf2', _ES2_PBKDF2, _pbes2_pbkdf2), ) _pbes2_prf_list = ( ('sha1', _ES2_SHA1), ('sha224', _ES2_SHA224), ('sha256', _ES2_SHA256), ('sha384', _ES2_SHA384), 
('sha512', _ES2_SHA512) ) for _args in _pkcs1_cipher_list: register_pkcs1_cipher(*_args) for _args in _pkcs8_cipher_list: register_pkcs8_cipher(*_args) for _args in _pbes2_cipher_list: register_pbes2_cipher(*_args) for _args in _pbes2_kdf_list: register_pbes2_kdf(*_args) for _args in _pbes2_prf_list: register_pbes2_prf(*_args) asyncssh-1.11.1/asyncssh/process.py000066400000000000000000001154131320320510200173020ustar00rootroot00000000000000# Copyright (c) 2016-2017 by Ron Frederick .else: # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH process handlers""" import asyncio from asyncio.subprocess import DEVNULL, PIPE, STDOUT from collections import OrderedDict import io import os import socket import stat from .constants import DEFAULT_LANG, DISC_PROTOCOL_ERROR, EXTENDED_DATA_STDERR from .misc import DisconnectError, Error, Record from .stream import SSHClientStreamSession, SSHServerStreamSession from .stream import SSHReader, SSHWriter def _is_regular_file(file): """Return if argument is a regular file or file-like object""" try: return stat.S_ISREG(os.fstat(file.fileno()).st_mode) except OSError: return True class _UnicodeReader: """Handle buffering partial Unicode data""" def __init__(self, encoding, textmode=False): self._encoding = encoding self._textmode = textmode self._partial = b'' def decode(self, data): """Decode Unicode bytes when reading from binary sources""" if self._encoding and not self._textmode: data = self._partial + data self._partial = b'' try: data = data.decode(self._encoding) except UnicodeDecodeError as exc: if exc.start > 0: # Avoid pylint false positive # pylint: disable=invalid-slice-index self._partial = data[exc.start:] data = data[:exc.start].decode(self._encoding) elif exc.reason == 'unexpected end of data': self._partial = data data = '' else: self.close() raise DisconnectError(DISC_PROTOCOL_ERROR, 'Unicode decode error') return data def check_partial(self): """Check if there's partial Unicode data left at EOF""" if self._partial: self.close() raise DisconnectError(DISC_PROTOCOL_ERROR, 'Unicode decode error') def close(self): """Perform necessary cleanup on error (provided by derived classes)""" class _UnicodeWriter: """Handle encoding Unicode data before writing it""" def __init__(self, encoding, textmode=False): self._encoding = encoding self._textmode = textmode def encode(self, data): """Encode Unicode bytes when writing to binary targets""" if self._encoding and not self._textmode: data = data.encode(self._encoding) return data class _FileReader(_UnicodeReader): """Forward data from a file""" def __init__(self, process, file, bufsize, datatype, encoding): super().__init__(encoding, hasattr(file, 'encoding')) self._process = process self._file = file self._bufsize = bufsize self._datatype = datatype self._paused = False def feed(self): """Feed file data""" while not self._paused: data = self._file.read(self._bufsize) if data: self._process.feed_data(self.decode(data), self._datatype) else: self.check_partial() self._process.feed_eof(self._datatype) break def pause_reading(self): """Pause reading from the file""" self._paused = True def resume_reading(self): """Resume reading from the file""" self._paused = False self.feed() def close(self): """Stop forwarding data from 
the file""" self._file.close() class _FileWriter(_UnicodeWriter): """Forward data to a file""" def __init__(self, file, encoding): super().__init__(encoding, hasattr(file, 'encoding')) self._file = file def write(self, data): """Write data to the file""" self._file.write(self.encode(data)) def write_eof(self): """Close output file when end of file is received""" self.close() def close(self): """Stop forwarding data to the file""" self._file.close() class _PipeReader(_UnicodeReader, asyncio.Protocol): """Forward data from a pipe""" def __init__(self, process, datatype, encoding): super().__init__(encoding) self._process = process self._datatype = datatype self._transport = None def connection_made(self, transport): """Handle a newly opened pipe""" self._transport = transport def data_received(self, data): """Forward data from the pipe""" self._process.feed_data(self.decode(data), self._datatype) def eof_received(self): """Forward EOF from the pipe""" self.check_partial() self._process.feed_eof(self._datatype) def pause_reading(self): """Pause reading from the pipe""" self._transport.pause_reading() def resume_reading(self): """Resume reading from the pipe""" self._transport.resume_reading() def close(self): """Stop forwarding data from the pipe""" self._transport.close() class _PipeWriter(_UnicodeWriter, asyncio.BaseProtocol): """Forward data to a pipe""" def __init__(self, process, datatype, encoding): super().__init__(encoding) self._process = process self._datatype = datatype self._transport = None def connection_made(self, transport): """Handle a newly opened pipe""" self._transport = transport def pause_writing(self): """Pause writing to the pipe""" self._process.pause_feeding(self._datatype) def resume_writing(self): """Resume writing to the pipe""" self._process.resume_feeding(self._datatype) def write(self, data): """Write data to the pipe""" self._transport.write(self.encode(data)) def write_eof(self): """Write EOF to the pipe""" self._transport.write_eof() def close(self): """Stop forwarding data to the pipe""" self._transport.close() class _ProcessReader: """Forward data from another SSH process""" def __init__(self, process, datatype): self._process = process self._datatype = datatype def pause_reading(self): """Pause reading from the other channel""" self._process.pause_feeding(self._datatype) def resume_reading(self): """Resume reading from the other channel""" self._process.resume_feeding(self._datatype) def close(self): """Stop forwarding data from the other channel""" self._process.clear_writer(self._datatype) class _ProcessWriter: """Forward data to another SSH process""" def __init__(self, process, datatype): self._process = process self._datatype = datatype def write(self, data): """Write data to the other channel""" self._process.feed_data(data, self._datatype) def write_eof(self): """Write EOF to the other channel""" self._process.feed_eof(self._datatype) def close(self): """Stop forwarding data to the other channel""" self._process.clear_reader(self._datatype) class _DevNullWriter: """Discard data""" def write(self, data): """Discard data being written""" def write_eof(self): """Ignore end of file""" def close(self): """Ignore close""" class _StdoutWriter: """Forward data to an SSH process' stdout instead of stderr""" def __init__(self, process): self._process = process def write(self, data): """Pretend data was received on stdout""" self._process.data_received(data, None) def write_eof(self): """Ignore end of file""" def close(self): """Ignore close""" class 
ProcessError(Error): """SSH Process error This exception is raised when an :class:`SSHClientProcess` exits with a non-zero exit status and error checking is enabled. In addition to the usual error code, reason, and language, it contains the following fields: ============ ======================================= ================ Field Description Type ============ ======================================= ================ env The environment the client requested str or ``None`` to be set for the process command The command the client requested the str or ``None`` process to execute (if any) subsystem The subsystem the client requested the str or ``None`` process to open (if any) exit_status The exit status returned, or -1 if an int exit signal is sent exit_signal The exit signal sent (if any) in the tuple or ``None`` form of a tuple containing the signal name, a bool for whether a core dump occurred, a message associated with the signal, and the language the message was in stdout The output sent by the process to str or bytes stdout (if not redirected) stderr The output sent by the process to str or bytes stderr (if not redirected) ============ ======================================= ================ """ def __init__(self, env, command, subsystem, exit_status, exit_signal, stdout, stderr): self.env = env self.command = command self.subsystem = subsystem self.exit_status = exit_status self.exit_signal = exit_signal self.stdout = stdout self.stderr = stderr if exit_signal: signal, core_dumped, msg, lang = exit_signal reason = 'Process exited with signal %s%s%s' % \ (signal, ': ' + msg if msg else '', ' (core dumped)' if core_dumped else '') else: reason = 'Process exited with non-zero exit status %s' % \ exit_status lang = DEFAULT_LANG super().__init__('Process', exit_status, reason, lang) class SSHCompletedProcess(Record): """Results from running an SSH process This object is returned by the :meth:`run ` method on :class:`SSHClientConnection` when the requested command has finished running. It contains the following fields: ============ ======================================= ================ Field Description Type ============ ======================================= ================ env The environment the client requested str or ``None`` to be set for the process command The command the client requested the str or ``None`` process to execute (if any) subsystem The subsystem the client requested the str or ``None`` process to open (if any) exit_status The exit status returned, or -1 if an int exit signal is sent exit_signal The exit signal sent (if any) in the tuple or ``None`` form of a tuple containing the signal name, a bool for whether a core dump occurred, a message associated with the signal, and the language the message was in stdout The output sent by the process to str or bytes stdout (if not redirected) stderr The output sent by the process to str or bytes stderr (if not redirected) ============ ======================================= ================ """ __slots__ = OrderedDict((('env', None), ('command', None), ('subsystem', None), ('exit_status', None), ('exit_signal', None), ('stdout', None), ('stderr', None))) class SSHProcess: """SSH process handler""" # Pylint doesn't know that all SSHProcess instances will always be # subclasses of SSHStreamSession. 
# pylint: disable=no-member def __init__(self): self._readers = {} self._send_eof = {} self._writers = {} self._paused_write_streams = set() self._stdin = None self._stdout = None self._stderr = None def __enter__(self): """Allow SSHProcess to be used as a context manager""" return self def __exit__(self, *exc_info): """Automatically close the channel when exiting the context""" self.close() @asyncio.coroutine def __aenter__(self): """Allow SSHProcess to be used as an async context manager""" return self @asyncio.coroutine def __aexit__(self, *exc_info): """Wait for a full channel close when exiting the async context""" self.close() yield from self._chan.wait_closed() @property def channel(self): """The channel associated with the process""" return self._chan @property def env(self): """The environment set by the client for the process This method returns the environment set by the client when the session was opened. :returns: A dictionary containing the environment variables set by the client """ return self._chan.get_environment() @property def command(self): """The command the client requested to execute, if any This method returns the command the client requested to execute when the process was started, if any. If the client did not request that a command be executed, this method will return ``None``. :returns: A str containing the command or ``None`` if no command was specified """ return self._chan.get_command() @property def subsystem(self): """The subsystem the client requested to open, if any This method returns the subsystem the client requested to open when the process was started, if any. If the client did not request that a subsystem be opened, this method will return ``None``. :returns: A str containing the subsystem name or ``None`` if no subsystem was specified """ return self._chan.get_subsystem() @asyncio.coroutine def _create_reader(self, source, bufsize, send_eof, datatype=None): """Create a reader to forward data to the SSH channel""" def pipe_factory(): """Return a pipe read handler""" return _PipeReader(self, datatype, self._encoding) if source == PIPE: reader = None elif source == DEVNULL: self._chan.write_eof() reader = None elif isinstance(source, SSHReader): reader_process, reader_datatype = source.get_redirect_info() writer = _ProcessWriter(self, datatype) reader_process.set_writer(writer, reader_datatype) reader = _ProcessReader(reader_process, reader_datatype) else: if isinstance(source, str): file = open(source, 'rb', buffering=bufsize) elif isinstance(source, int): file = os.fdopen(source, 'rb', buffering=bufsize) elif isinstance(source, socket.socket): file = os.fdopen(source.detach(), 'rb', buffering=bufsize) else: file = source if _is_regular_file(file): reader = _FileReader(self, file, bufsize, datatype, self._encoding) else: if hasattr(source, 'buffer'): # If file was opened in text mode, remove that wrapper file = source.buffer _, reader = \ yield from self._loop.connect_read_pipe(pipe_factory, file) self.set_reader(reader, send_eof, datatype) if isinstance(reader, _FileReader): reader.feed() elif isinstance(reader, _ProcessReader): reader_process.feed_recv_buf(reader_datatype, writer) @asyncio.coroutine def _create_writer(self, target, bufsize, send_eof, datatype=None): """Create a writer to forward data from the SSH channel""" def pipe_factory(): """Return a pipe write handler""" return _PipeWriter(self, datatype, self._encoding) if target == DEVNULL: writer = _DevNullWriter() elif target == PIPE: writer = None elif target == STDOUT: writer = 
_StdoutWriter(self) elif isinstance(target, SSHWriter): writer_process, writer_datatype = target.get_redirect_info() reader = _ProcessReader(self, datatype) writer_process.set_reader(reader, send_eof, writer_datatype) writer = _ProcessWriter(writer_process, writer_datatype) else: if isinstance(target, str): file = open(target, 'wb', buffering=bufsize) elif isinstance(target, int): file = os.fdopen(target, 'wb', buffering=bufsize) elif isinstance(target, socket.socket): file = os.fdopen(target.detach(), 'wb', buffering=bufsize) else: file = target if _is_regular_file(file): writer = _FileWriter(file, self._encoding) else: if hasattr(target, 'buffer'): # If file was opened in text mode, remove that wrapper file = target.buffer _, writer = \ yield from self._loop.connect_write_pipe(pipe_factory, file) self.set_writer(writer, datatype) if writer: self.feed_recv_buf(datatype, writer) def _should_block_drain(self, datatype): """Return whether output is still being written to the channel""" return (datatype in self._readers or super()._should_block_drain(datatype)) def _should_pause_reading(self): """Return whether to pause reading from the channel""" return self._paused_write_streams or super()._should_pause_reading() def connection_lost(self, exc): """Handle a close of the SSH channel""" super().connection_lost(exc) for reader in self._readers.values(): reader.close() for writer in self._writers.values(): writer.close() self._readers = {} self._writers = {} def data_received(self, data, datatype): """Handle incoming data from the SSH channel""" writer = self._writers.get(datatype) if writer: writer.write(data) else: super().data_received(data, datatype) def eof_received(self): """Handle an incoming end of file from the SSH channel""" for writer in list(self._writers.values()): writer.write_eof() return super().eof_received() def pause_writing(self): """Pause forwarding data to the channel""" super().pause_writing() for reader in self._readers.values(): reader.pause_reading() def resume_writing(self): """Resume forwarding data to the channel""" super().resume_writing() for reader in list(self._readers.values()): reader.resume_reading() def feed_data(self, data, datatype): """Feed data to the channel""" self._chan.write(data, datatype) def feed_eof(self, datatype): """Feed EOF to the channel""" if self._send_eof[datatype]: self._chan.write_eof() self._readers[datatype].close() self.clear_reader(datatype) def feed_recv_buf(self, datatype, writer): """Feed current receive buffer to a newly set writer""" for data in self._recv_buf[datatype]: writer.write(data) self._recv_buf_len -= len(data) self._recv_buf[datatype].clear() if self._eof_received: writer.write_eof() self._maybe_resume_reading() def pause_feeding(self, datatype): """Pause feeding data from the channel""" self._paused_write_streams.add(datatype) self._maybe_pause_reading() def resume_feeding(self, datatype): """Resume feeding data from the channel""" self._paused_write_streams.remove(datatype) self._maybe_resume_reading() def set_reader(self, reader, send_eof, datatype): """Set a reader used to forward data to the channel""" old_reader = self._readers.get(datatype) if old_reader: old_reader.close() if reader: self._readers[datatype] = reader self._send_eof[datatype] = send_eof if self._write_paused: reader.pause_reading() elif old_reader: self.clear_reader(datatype) def clear_reader(self, datatype): """Clear a reader forwarding data to the channel""" del self._readers[datatype] del self._send_eof[datatype] self._unblock_drain(datatype) 
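    # Redirection plumbing: readers forward data *to* the SSH channel (at
    # most one per datatype, optionally sending EOF when their source is
    # exhausted), while writers forward data received *from* the channel.
    # When a writer is attached after data has already been received,
    # feed_recv_buf() flushes anything buffered for that datatype to the
    # new writer before later data is delivered.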
def set_writer(self, writer, datatype): """Set a writer used to forward data from the channel""" old_writer = self._writers.get(datatype) if old_writer: old_writer.close() self.clear_writer(datatype) if writer: self._writers[datatype] = writer def clear_writer(self, datatype): """Clear a writer forwarding data from the channel""" if datatype in self._paused_write_streams: self.resume_feeding(datatype) del self._writers[datatype] def close(self): """Shut down the process""" self._chan.close() @asyncio.coroutine def wait_closed(self): """Wait for the process to finish shutting down""" yield from self._chan.wait_closed() class SSHClientProcess(SSHProcess, SSHClientStreamSession): """SSH client process handler""" def __init__(self): SSHProcess.__init__(self) SSHClientStreamSession.__init__(self) def _collect_output(self, datatype=None): """Return output from the process""" recv_buf = self._recv_buf[datatype] if recv_buf and isinstance(recv_buf[-1], Exception): recv_buf = recv_buf[:-1] buf = '' if self._encoding else b'' return buf.join(recv_buf) def session_started(self): """Start a process for this newly opened client channel""" self._stdin = SSHWriter(self, self._chan) self._stdout = SSHReader(self, self._chan) self._stderr = SSHReader(self, self._chan, EXTENDED_DATA_STDERR) @property def exit_status(self): """The exit status of the process""" return self._chan.get_exit_status() @property def exit_signal(self): """Exit signal information for the process""" return self._chan.get_exit_signal() @property def stdin(self): """The :class:`SSHWriter` to use to write to stdin of the process""" return self._stdin @property def stdout(self): """The :class:`SSHReader` to use to read from stdout of the process""" return self._stdout @property def stderr(self): """The :class:`SSHReader` to use to read from stderr of the process""" return self._stderr @asyncio.coroutine def redirect(self, stdin=None, stdout=None, stderr=None, bufsize=io.DEFAULT_BUFFER_SIZE, send_eof=True): """Perform I/O redirection for the process This method redirects data going to or from any or all of standard input, standard output, and standard error for the process. The ``stdin`` argument can be any of the following: * An :class:`SSHReader` object * A file object open for read * An int file descriptor open for read * A connected socket object * A string containing the name of a file or device to open * ``DEVNULL`` to provide no input to standard input * ``PIPE`` to interactively write standard input The ``stdout`` and ``stderr`` arguments can be any of the following: * An :class:`SSHWriter` object * A file object open for write * An int file descriptor open for write * A connected socket object * A string containing the name of a file or device to open * ``DEVNULL`` to discard standard error output * ``PIPE`` to interactively read standard error output The ``stderr`` argument also accepts the value ``STDOUT`` to request that standard error output be delivered to stdout. File objects passed in can be associated with plain files, pipes, sockets, or ttys. The default value of ``None`` means to not change redirection for that stream. :param stdin: Source of data to feed to standard input :param stdout: Target to feed data from standard output to :param stderr: Target to feed data from standard error to :param int bufsize: Buffer size to use when forwarding data from a file :param bool send_eof: Whether or not to send EOF to the channel when redirection is complete, defaulting to ``True``. 
If set to ``False``, multiple sources can be sequentially fed to the channel. """ if stdin: yield from self._create_reader(stdin, bufsize, send_eof) if stdout: yield from self._create_writer(stdout, bufsize, send_eof) if stderr: yield from self._create_writer(stderr, bufsize, send_eof, EXTENDED_DATA_STDERR) @asyncio.coroutine def redirect_stdin(self, source, bufsize=io.DEFAULT_BUFFER_SIZE, send_eof=True): """Redirect standard input of the process""" yield from self.redirect(source, None, None, bufsize, send_eof) @asyncio.coroutine def redirect_stdout(self, target, bufsize=io.DEFAULT_BUFFER_SIZE, send_eof=True): """Redirect standard output of the process""" yield from self.redirect(None, target, None, bufsize, send_eof) @asyncio.coroutine def redirect_stderr(self, target, bufsize=io.DEFAULT_BUFFER_SIZE, send_eof=True): """Redirect standard error of the process""" yield from self.redirect(None, None, target, bufsize, send_eof) # pylint: disable=redefined-builtin @asyncio.coroutine def communicate(self, input=None): """Send input to and/or collect output from the process This method is a coroutine which optionally provides input to the process and then waits for the process to exit, returning a tuple of the data written to stdout and stderr. :param input: Input data to feed to standard input of the process. Data should be a str if encoding is set, or bytes if not. :type input: str or bytes :returns: A tuple of output to stdout and stderr """ self._limit = None self._maybe_resume_reading() if input: self._chan.write(input) self._chan.write_eof() yield from self._chan.wait_closed() return (self._collect_output(), self._collect_output(EXTENDED_DATA_STDERR)) # pylint: enable=redefined-builtin def change_terminal_size(self, width, height, pixwidth=0, pixheight=0): """Change the terminal window size for this process This method changes the width and height of the terminal associated with this process. :param int width: The width of the terminal in characters :param int height: The height of the terminal in characters :param int pixwidth: (optional) The width of the terminal in pixels :param int pixheight: (optional) The height of the terminal in pixels :raises: :exc:`OSError` if the SSH channel is not open """ self._chan.change_terminal_size(width, height, pixwidth, pixheight) def send_break(self, msec): """Send a break to the process :param int msec: The duration of the break in milliseconds :raises: :exc:`OSError` if the SSH channel is not open """ self._chan.send_break(msec) def send_signal(self, signal): """Send a signal to the process :param str signal: The signal to deliver :raises: :exc:`OSError` if the SSH channel is not open """ self._chan.send_signal(signal) def terminate(self): """Terminate the process :raises: :exc:`OSError` if the SSH channel is not open """ self._chan.terminate() def kill(self): """Forcibly kill the process :raises: :exc:`OSError` if the SSH channel is not open """ self._chan.kill() @asyncio.coroutine def wait(self, check=False): """Wait for process to exit This method is a coroutine which waits for the process to exit. It returns an :class:`SSHCompletedProcess` object with the exit status or signal information and the output sent to stdout and stderr if those are redirected to pipes. If the check argument is set to ``True``, a non-zero exit status from the process with trigger the :exc:`ProcessError` exception to be raised. 
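        For example, a minimal sketch of checked waiting (``conn`` here
        stands for an already-established :class:`SSHClientConnection`,
        and ``exit 1`` is just an illustrative command)::

            @asyncio.coroutine
            def run_checked(conn):
                process = yield from conn.create_process('exit 1')

                try:
                    yield from process.wait(check=True)
                except ProcessError as exc:
                    print('Command failed with exit status', exc.exit_status)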
:param bool check: Whether or not to raise an error on non-zero exit status :returns: :class:`SSHCompletedProcess` :raises: :exc:`ProcessError` if check is set to ``True`` and the process returns a non-zero exit status """ stdout_data, stderr_data = yield from self.communicate() if check and self.exit_status: raise ProcessError(self.env, self.command, self.subsystem, self.exit_status, self.exit_signal, stdout_data, stderr_data) else: return SSHCompletedProcess(self.env, self.command, self.subsystem, self.exit_status, self.exit_signal, stdout_data, stderr_data) class SSHServerProcess(SSHProcess, SSHServerStreamSession): """SSH server process handler""" def __init__(self, process_factory, sftp_factory, allow_scp): SSHProcess.__init__(self) SSHServerStreamSession.__init__(self, self._start_process, sftp_factory, allow_scp) self._process_factory = process_factory def _start_process(self, stdin, stdout, stderr): """Start a new server process""" self._stdin = stdin self._stdout = stdout self._stderr = stderr return self._process_factory(self) @property def stdin(self): """The :class:`SSHReader` to use to read from stdin of the process""" return self._stdin @property def stdout(self): """The :class:`SSHWriter` to use to write to stdout of the process""" return self._stdout @property def stderr(self): """The :class:`SSHWriter` to use to write to stderr of the process""" return self._stderr @asyncio.coroutine def redirect(self, stdin=None, stdout=None, stderr=None, bufsize=io.DEFAULT_BUFFER_SIZE, send_eof=True): """Perform I/O redirection for the process This method redirects data going to or from any or all of standard input, standard output, and standard error for the process. The ``stdin`` argument can be any of the following: * An :class:`SSHWriter` object * A file object open for write * An int file descriptor open for write * A connected socket object * A string containing the name of a file or device to open * ``DEVNULL`` to discard standard error output * ``PIPE`` to interactively read standard error output The ``stdout`` and ``stderr`` arguments can be any of the following: * An :class:`SSHReader` object * A file object open for read * An int file descriptor open for read * A connected socket object * A string containing the name of a file or device to open * ``DEVNULL`` to provide no input to standard input * ``PIPE`` to interactively write standard input File objects passed in can be associated with plain files, pipes, sockets, or ttys. The default value of ``None`` means to not change redirection for that stream. :param stdin: Target to feed data from standard input to :param stdout: Source of data to feed to standard output :param stderr: Source of data to feed to standard error :param int bufsize: Buffer size to use when forwarding data from a file :param bool send_eof: Whether or not to send EOF to the channel when redirection is complete, defaulting to ``True``. If set to ``False``, multiple sources can be sequentially fed to the channel. 
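        As an illustrative sketch (``handle_client`` is a placeholder name
        for a ``process_factory`` handler), a server-side process could
        discard anything the client sends and write a short greeting of
        its own::

            @asyncio.coroutine
            def handle_client(process):
                yield from process.redirect(stdin=DEVNULL)
                process.stdout.write('Hello from the server!\n')
                process.exit(0)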
""" if stdin: yield from self._create_writer(stdin, bufsize, send_eof) if stdout: yield from self._create_reader(stdout, bufsize, send_eof) if stderr: yield from self._create_reader(stderr, bufsize, send_eof, EXTENDED_DATA_STDERR) @asyncio.coroutine def redirect_stdin(self, target, bufsize=io.DEFAULT_BUFFER_SIZE, send_eof=True): """Redirect standard input of the process""" yield from self.redirect(target, None, None, bufsize, send_eof) @asyncio.coroutine def redirect_stdout(self, source, bufsize=io.DEFAULT_BUFFER_SIZE, send_eof=True): """Redirect standard output of the process""" yield from self.redirect(None, source, None, bufsize, send_eof) @asyncio.coroutine def redirect_stderr(self, source, bufsize=io.DEFAULT_BUFFER_SIZE, send_eof=True): """Redirect standard error of the process""" yield from self.redirect(None, None, source, bufsize, send_eof) def get_environment(self): """Return the environment set by the client (deprecated)""" return self.env # pragma: no cover def get_command(self): """Return the command the client requested to execute (deprecated)""" return self.command # pragma: no cover def get_subsystem(self): """Return the subsystem the client requested to open (deprecated)""" return self.subsystem # pragma: no cover def get_terminal_type(self): """Return the terminal type set by the client for the process This method returns the terminal type set by the client when the process was started. If the client didn't request a pseudo-terminal, this method will return ``None``. :returns: A str containing the terminal type or ``None`` if no pseudo-terminal was requested """ return self._chan.get_terminal_type() def get_terminal_size(self): """Return the terminal size set by the client for the process This method returns the latest terminal size information set by the client. If the client didn't set any terminal size information, all values returned will be zero. :returns: A tuple of four integers containing the width and height of the terminal in characters and the width and height of the terminal in pixels """ return self._chan.get_terminal_size() def get_terminal_mode(self, mode): """Return the requested TTY mode for this session This method looks up the value of a POSIX terminal mode set by the client when the process was started. If the client didn't request a pseudo-terminal or didn't set the requested TTY mode opcode, this method will return ``None``. :param int mode: POSIX terminal mode taken from :ref:`POSIX terminal modes ` to look up :returns: An int containing the value of the requested POSIX terminal mode or ``None`` if the requested mode was not set """ return self._chan.get_terminal_mode(mode) def exit(self, status): """Send exit status and close the channel This method can be called to report an exit status for the process back to the client and close the channel. :param int status: The exit status to report to the client """ self._chan.exit(status) def exit_with_signal(self, signal, core_dumped=False, msg='', lang=DEFAULT_LANG): """Send exit signal and close the channel This method can be called to report that the process terminated abnormslly with a signal. A more detailed error message may also provided, along with an indication of whether or not the process dumped core. After reporting the signal, the channel is closed. 
:param str signal: The signal which caused the process to exit :param bool core_dumped: (optional) Whether or not the process dumped core :param str msg: (optional) Details about what error occurred :param str lang: (optional) The language the error message is in """ return self._chan.exit_with_signal(signal, core_dumped, msg, lang) asyncssh-1.11.1/asyncssh/public_key.py000066400000000000000000003024741320320510200177570ustar00rootroot00000000000000# Copyright (c) 2013-2017 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH asymmetric encryption handlers""" import binascii from datetime import datetime, timedelta import os import re import time try: from .crypto import generate_x509_certificate, import_x509_certificate _x509_available = True except ImportError: # pragma: no cover _x509_available = False try: import bcrypt _bcrypt_available = hasattr(bcrypt, 'kdf') except ImportError: # pragma: no cover _bcrypt_available = False from .asn1 import ASN1DecodeError, BitString, der_encode, der_decode from .cipher import get_encryption_params, get_cipher from .misc import ip_network from .packet import NameList, String, UInt32, UInt64 from .packet import PacketDecodeError, SSHPacket from .pbe import KeyEncryptionError, pkcs1_encrypt, pkcs8_encrypt from .pbe import pkcs1_decrypt, pkcs8_decrypt _public_key_algs = [] _certificate_algs = [] _x509_certificate_algs = [] _public_key_alg_map = {} _certificate_alg_map = {} _certificate_version_map = {} _pem_map = {} _pkcs8_oid_map = {} _abs_date_pattern = re.compile(r'\d{8}') _abs_time_pattern = re.compile(r'\d{14}') _rel_time_pattern = re.compile(r'(?:(?P[+-]?\d+)[Ww]|' r'(?P[+-]?\d+)[Dd]|' r'(?P[+-]?\d+)[Hh]|' r'(?P[+-]?\d+)[Mm]|' r'(?P[+-]?\d+)[Ss])+') _subject_pattern = re.compile(r'(?:Distinguished[ -_]?Name|Subject|DN)[=:]?\s?', re.IGNORECASE) # SSH certificate types CERT_TYPE_USER = 1 CERT_TYPE_HOST = 2 _OPENSSH_KEY_V1 = b'openssh-key-v1\0' _OPENSSH_SALT_LEN = 16 _OPENSSH_WRAP_LEN = 70 def _parse_time(t): """Parse a time value""" if isinstance(t, int): return t elif isinstance(t, float): return int(t) elif isinstance(t, datetime): return int(t.timestamp()) elif isinstance(t, str): if t == 'now': return int(time.time()) match = _abs_date_pattern.fullmatch(t) if match: return int(datetime.strptime(t, '%Y%m%d').timestamp()) match = _abs_time_pattern.fullmatch(t) if match: return int(datetime.strptime(t, '%Y%m%d%H%M%S').timestamp()) match = _rel_time_pattern.fullmatch(t) if match: delta = {k: int(v) for k, v in match.groupdict(0).items()} return int(time.time() + timedelta(**delta).total_seconds()) raise ValueError('Unrecognized time value') def _wrap_base64(data, wrap=64): """Break a Base64 value into multiple lines.""" data = binascii.b2a_base64(data)[:-1] return b'\n'.join(data[i:i+wrap] for i in range(0, len(data), wrap)) + b'\n' class KeyGenerationError(ValueError): """Key generation error This exception is raised by :func:`generate_private_key`, :meth:`generate_user_certificate() ` or :meth:`generate_host_certificate() ` when the requested parameters are unsupported. """ class KeyImportError(ValueError): """Key import error This exception is raised by key import functions when the data provided cannot be imported as a valid key. 
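    For example, callers loading key data from untrusted input will
    typically want to catch this (``import_private_key`` is one of the
    import functions provided later in this module, and ``data`` and
    ``passphrase`` stand for caller-supplied values)::

        try:
            key = import_private_key(data, passphrase)
        except KeyImportError:
            key = None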
""" class KeyExportError(ValueError): """Key export error This exception is raised by key export functions when the requested format is unknown or encryption is requested for a format which doesn't support it. """ class SSHKey: """Parent class which holds an asymmetric encryption key""" algorithm = None sig_algorithms = None x509_algorithms = None all_sig_algorithms = None pem_name = None pkcs8_oid = None def __init__(self, key=None): self._key = key self._comment = None @property def pyca_key(self): """Return PyCA key for use in X.509 module""" return self._key.pyca_key def _generate_certificate(self, key, version, serial, cert_type, key_id, principals, valid_after, valid_before, cert_options, comment): """Generate a new SSH certificate""" valid_after = _parse_time(valid_after) valid_before = _parse_time(valid_before) if valid_before <= valid_after: raise ValueError('Valid before time must be later than ' 'valid after time') try: algorithm, cert_handler = _certificate_version_map[key.algorithm, version] except KeyError: raise KeyGenerationError('Unknown certificate version') from None return cert_handler.generate(self, algorithm, key, serial, cert_type, key_id, principals, valid_after, valid_before, cert_options, comment) def _generate_x509_certificate(self, key, subject, issuer, serial, valid_after, valid_before, ca, ca_path_len, purposes, user_principals, host_principals, hash_alg, comment): """Generate a new X.509 certificate""" if not _x509_available: # pragma: no cover raise KeyGenerationError('X.509 certificate generation ' 'requires PyOpenSSL') if not self.x509_algorithms: raise KeyGenerationError('X.509 certificate generation not ' 'supported for ' + self.get_algorithm() + ' keys') valid_after = _parse_time(valid_after) valid_before = _parse_time(valid_before) if valid_before <= valid_after: raise ValueError('Valid before time must be later than ' 'valid after time') return SSHX509Certificate.generate(self, key, subject, issuer, serial, valid_after, valid_before, ca, ca_path_len, purposes, user_principals, host_principals, hash_alg, comment) def get_algorithm(self): """Return the algorithm associated with this key""" return self.algorithm.decode('ascii') def get_comment(self): """Return the comment associated with this key :returns: `str` or ``None`` """ return self._comment def set_comment(self, comment): """Set the comment associated with this key :param comment: The new comment to associate with this key :type comment: `str` or ``None`` """ if isinstance(comment, bytes): try: comment = comment.decode('utf-8') except UnicodeDecodeError: raise KeyImportError('Invalid characters in comment') from None self._comment = comment or None def sign_der(self, data, sig_algorithm): """Abstract method to compute a DER-encoded signature""" raise NotImplementedError def verify_der(self, data, sig_algorithm, sig): """Abstract method to verify a DER-encoded signature""" raise NotImplementedError def sign_ssh(self, data, sig_algorithm): """Abstract method to compute an SSH-encoded signature""" raise NotImplementedError def verify_ssh(self, data, sig_algorithm, sig): """Abstract method to verify an SSH-encoded signature""" raise NotImplementedError def sign(self, data, sig_algorithm): """Return an SSH-encoded signature of the specified data""" if sig_algorithm not in self.all_sig_algorithms: raise ValueError('Unrecognized signature algorithm') return b''.join((String(sig_algorithm), String(self.sign_ssh(data, sig_algorithm)))) def verify(self, data, sig): """Verify an SSH signature of the 
specified data using this key""" try: packet = SSHPacket(sig) sig_algorithm = packet.get_string() sig = packet.get_string() packet.check_end() if sig_algorithm not in self.all_sig_algorithms: return False return self.verify_ssh(data, sig_algorithm, sig) except PacketDecodeError: return False def encode_pkcs1_private(self): """Export parameters associated with a PKCS#1 private key""" # pylint: disable=no-self-use raise KeyExportError('PKCS#1 private key export not supported') def encode_pkcs1_public(self): """Export parameters associated with a PKCS#1 public key""" # pylint: disable=no-self-use raise KeyExportError('PKCS#1 public key export not supported') def encode_pkcs8_private(self): """Export parameters associated with a PKCS#8 private key""" # pylint: disable=no-self-use raise KeyExportError('PKCS#8 private key export not supported') def encode_pkcs8_public(self): """Export parameters associated with a PKCS#8 public key""" # pylint: disable=no-self-use raise KeyExportError('PKCS#8 public key export not supported') def encode_ssh_private(self): """Export parameters associated with an OpenSSH private key""" # pylint: disable=no-self-use raise KeyExportError('OpenSSH private key export not supported') def encode_ssh_public(self): """Export parameters associated with an OpenSSH public key""" # pylint: disable=no-self-use raise KeyExportError('OpenSSH public key export not supported') def get_ssh_private_key(self): """Return OpenSSH private key in binary format""" return String(self.algorithm) + self.encode_ssh_private() def get_ssh_public_key(self): """Return OpenSSH public key in binary format""" return String(self.algorithm) + self.encode_ssh_public() def convert_to_public(self): """Return public key corresponding to this key This method converts an :class:`SSHKey` object which contains a private key into one which contains only the corresponding public key. If it is called on something which is already a public key, it has no effect. """ result = decode_ssh_public_key(self.get_ssh_public_key()) result.set_comment(self.get_comment()) return result def generate_user_certificate(self, user_key, key_id, version=1, serial=0, principals=(), valid_after=0, valid_before=0xffffffffffffffff, force_command=None, source_address=None, permit_x11_forwarding=True, permit_agent_forwarding=True, permit_port_forwarding=True, permit_pty=True, permit_user_rc=True, comment=()): """Generate a new SSH user certificate This method returns an SSH user certifcate with the requested attributes signed by this private key. :param user_key: The user's public key. :param str key_id: The key identifier associated with this certificate. :param int version: (optional) The version of certificate to create, defaulting to 1. :param int serial: (optional) The serial number of the certificate, defaulting to 0. :param principals: (optional) The user names this certificate is valid for. By default, it can be used with any user name. :param valid_after: (optional) The earliest time the certificate is valid for, defaulting to no restriction on when the certificate starts being valid. See :ref:`SpecifyingTimeValues` for allowed time specifications. :param valid_before: (optional) The latest time the certificate is valid for, defaulting to no restriction on when the certificate stops being valid. See :ref:`SpecifyingTimeValues` for allowed time specifications. :param force_command: (optional) The command (if any) to force a session to run when this certificate is used. 
:param source_address: (optional) A list of source addresses and networks for which the certificate is valid, defaulting to all addresses. :param bool permit_x11_forwarding: (optional) Whether or not to allow this user to use X11 forwarding, defaulting to ``True``. :param bool permit_agent_forwarding: (optional) Whether or not to allow this user to use agent forwarding, defaulting to ``True``. :param bool permit_port_forwarding: (optional) Whether or not to allow this user to use port forwarding, defaulting to ``True``. :param bool permit_pty: (optional) Whether or not to allow this user to allocate a pseudo-terminal, defaulting to ``True``. :param bool permit_user_rc: (optional) Whether or not to run the user rc file when this certificate is used, defaulting to ``True``. :param comment: The comment to associate with this certificate. By default, the comment will be set to the comment currently set on user_key. :type user_key: :class:`SSHKey` :type principals: list of strings :type force_command: `str` or ``None`` :type source_address: list of ip_address and ip_network values :type comment: `str` or ``None`` :returns: :class:`SSHCertificate` :raises: | :exc:`ValueError` if the validity times are invalid | :exc:`KeyGenerationError` if the requested certificate parameters are unsupported """ cert_options = {} if force_command: cert_options['force-command'] = force_command if source_address: cert_options['source-address'] = [ip_network(addr) for addr in source_address] if permit_x11_forwarding: cert_options['permit-X11-forwarding'] = True if permit_agent_forwarding: cert_options['permit-agent-forwarding'] = True if permit_port_forwarding: cert_options['permit-port-forwarding'] = True if permit_pty: cert_options['permit-pty'] = True if permit_user_rc: cert_options['permit-user-rc'] = True if comment is (): comment = user_key.get_comment() return self._generate_certificate(user_key, version, serial, CERT_TYPE_USER, key_id, principals, valid_after, valid_before, cert_options, comment) def generate_host_certificate(self, host_key, key_id, version=1, serial=0, principals=(), valid_after=0, valid_before=0xffffffffffffffff, comment=()): """Generate a new SSH host certificate This method returns an SSH host certifcate with the requested attributes signed by this private key. :param host_key: The host's public key. :param str key_id: The key identifier associated with this certificate. :param int version: (optional) The version of certificate to create, defaulting to 1. :param int serial: (optional) The serial number of the certificate, defaulting to 0. :param principals: (optional) The host names this certificate is valid for. By default, it can be used with any host name. :param valid_after: (optional) The earliest time the certificate is valid for, defaulting to no restriction on when the certificate starts being valid. See :ref:`SpecifyingTimeValues` for allowed time specifications. :param valid_before: (optional) The latest time the certificate is valid for, defaulting to no restriction on when the certificate stops being valid. See :ref:`SpecifyingTimeValues` for allowed time specifications. :param comment: The comment to associate with this certificate. By default, the comment will be set to the comment currently set on host_key. 
:type host_key: :class:`SSHKey` :type principals: list of strings :type comment: `str` or ``None`` :returns: :class:`SSHCertificate` :raises: | :exc:`ValueError` if the validity times are invalid | :exc:`KeyGenerationError` if the requested certificate parameters are unsupported """ if comment is (): comment = host_key.get_comment() return self._generate_certificate(host_key, version, serial, CERT_TYPE_HOST, key_id, principals, valid_after, valid_before, {}, comment) def generate_x509_user_certificate(self, user_key, subject, issuer=None, serial=None, principals=(), valid_after=0, valid_before=0xffffffffffffffff, purposes='secureShellClient', hash_alg='sha256', comment=()): """Generate a new X.509 user certificate This method returns an X.509 user certifcate with the requested attributes signed by this private key. :param user_key: The user's public key. :param str subject: The subject name in the certificate, expresed as a comma-separated list of X.509 ``name=value`` pairs. :param str issuer: (optional) The issuer name in the certificate, expresed as a comma-separated list of X.509 ``name=value`` pairs. If not specified, the subject name will be used, creating a self-signed certificate. :param int serial: (optional) The serial number of the certificate, defaulting to a random 64-bit value. :param principals: (optional) The user names this certificate is valid for. By default, it can be used with any user name. :param valid_after: (optional) The earliest time the certificate is valid for, defaulting to no restriction on when the certificate starts being valid. See :ref:`SpecifyingTimeValues` for allowed time specifications. :param valid_before: (optional) The latest time the certificate is valid for, defaulting to no restriction on when the certificate stops being valid. See :ref:`SpecifyingTimeValues` for allowed time specifications. :param purposes: (optional) The allowed purposes for this certificate or ``None`` to not restrict the certificate's purpose, defaulting to 'secureShellClient' :param str hash_alg: (optional) The hash algorithm to use when signing the new certificate, defaulting to SHA256. :param comment: (optional) The comment to associate with this certificate. By default, the comment will be set to the comment currently set on user_key. :type user_key: :class:`SSHKey` :type principals: list of strings :type purposes: list of strings or ``None`` :type comment: `str` or ``None`` :returns: :class:`SSHCertificate` :raises: | :exc:`ValueError` if the validity times are invalid | :exc:`KeyGenerationError` if the requested certificate parameters are unsupported """ if comment is (): comment = user_key.get_comment() return self._generate_x509_certificate(user_key, subject, issuer, serial, valid_after, valid_before, False, None, purposes, principals, (), hash_alg, comment) def generate_x509_host_certificate(self, host_key, subject, issuer=None, serial=None, principals=(), valid_after=0, valid_before=0xffffffffffffffff, purposes='secureShellServer', hash_alg='sha256', comment=()): """Generate a new X.509 host certificate This method returns a X.509 host certifcate with the requested attributes signed by this private key. :param host_key: The host's public key. :param str subject: The subject name in the certificate, expresed as a comma-separated list of X.509 ``name=value`` pairs. :param str issuer: (optional) The issuer name in the certificate, expresed as a comma-separated list of X.509 ``name=value`` pairs. 
If not specified, the subject name will be used, creating a self-signed certificate. :param int serial: (optional) The serial number of the certificate, defaulting to a random 64-bit value. :param principals: (optional) The host names this certificate is valid for. By default, it can be used with any host name. :param valid_after: (optional) The earliest time the certificate is valid for, defaulting to no restriction on when the certificate starts being valid. See :ref:`SpecifyingTimeValues` for allowed time specifications. :param valid_before: (optional) The latest time the certificate is valid for, defaulting to no restriction on when the certificate stops being valid. See :ref:`SpecifyingTimeValues` for allowed time specifications. :param purposes: (optional) The allowed purposes for this certificate or ``None`` to not restrict the certificate's purpose, defaulting to 'secureShellServer' :param str hash_alg: (optional) The hash algorithm to use when signing the new certificate, defaulting to SHA256. :param comment: (optional) The comment to associate with this certificate. By default, the comment will be set to the comment currently set on host_key. :type host_key: :class:`SSHKey` :type principals: list of strings :type purposes: list of strings or ``None`` :type comment: `str` or ``None`` :returns: :class:`SSHCertificate` :raises: | :exc:`ValueError` if the validity times are invalid | :exc:`KeyGenerationError` if the requested certificate parameters are unsupported """ if comment is (): comment = host_key.get_comment() return self._generate_x509_certificate(host_key, subject, issuer, serial, valid_after, valid_before, False, None, purposes, (), principals, hash_alg, comment) def generate_x509_ca_certificate(self, ca_key, subject, issuer=None, serial=None, valid_after=0, valid_before=0xffffffffffffffff, ca_path_len=None, hash_alg='sha256', comment=()): """Generate a new X.509 CA certificate This method returns a X.509 CA certifcate with the requested attributes signed by this private key. :param ca_key: The new CA's public key. :param str subject: The subject name in the certificate, expresed as a comma-separated list of X.509 ``name=value`` pairs. :param str issuer: (optional) The issuer name in the certificate, expresed as a comma-separated list of X.509 ``name=value`` pairs. If not specified, the subject name will be used, creating a self-signed certificate. :param int serial: (optional) The serial number of the certificate, defaulting to a random 64-bit value. :param valid_after: (optional) The earliest time the certificate is valid for, defaulting to no restriction on when the certificate starts being valid. See :ref:`SpecifyingTimeValues` for allowed time specifications. :param valid_before: (optional) The latest time the certificate is valid for, defaulting to no restriction on when the certificate stops being valid. See :ref:`SpecifyingTimeValues` for allowed time specifications. :param ca_path_len: (optional) The maximum number of levels of intermediate CAs allowed below this new CA or ``None`` to not enforce a limit, defaulting to no limit. :param str hash_alg: (optional) The hash algorithm to use when signing the new certificate, defaulting to SHA256. :param comment: (optional) The comment to associate with this certificate. By default, the comment will be set to the comment currently set on ca_key. 
:type ca_key: :class:`SSHKey` :type ca_path_len: `int` or ``None`` :type comment: `str` or ``None`` :returns: :class:`SSHCertificate` :raises: | :exc:`ValueError` if the validity times are invalid | :exc:`KeyGenerationError` if the requested certificate parameters are unsupported """ if comment is (): comment = ca_key.get_comment() return self._generate_x509_certificate(ca_key, subject, issuer, serial, valid_after, valid_before, True, ca_path_len, None, (), (), hash_alg, comment) def export_private_key(self, format_name='openssh', passphrase=None, cipher_name='aes256-cbc', hash_name='sha256', pbe_version=2, rounds=128): """Export a private key in the requested format This method returns this object's private key encoded in the requested format. If a passphrase is specified, the key will be exported in encrypted form. Available formats include: pkcs1-der, pkcs1-pem, pkcs8-der, pkcs8-pem, openssh By default, openssh format will be used. Encryption is supported in pkcs1-pem, pkcs8-der, pkcs8-pem, and openssh formats. For pkcs1-pem, only the cipher can be specified. For pkcs8-der and pkcs-8, cipher, hash and PBE version can be specified. For openssh, cipher and rounds can be specified. Available ciphers for pkcs1-pem are: aes128-cbc, aes192-cbc, aes256-cbc, des-cbc, des3-cbc Available ciphers for pkcs8-der and pkcs8-pem are: aes128-cbc, aes192-cbc, aes256-cbc, blowfish-cbc, cast128-cbc, des-cbc, des2-cbc, des3-cbc, rc4-40, rc4-128 Available ciphers for openssh format include the following :ref:`encryption algorithms `. Available hashes include: md5, sha1, sha256, sha384, sha512 Available PBE versions include 1 for PBES1 and 2 for PBES2. Not all combinations of cipher, hash, and version are supported. The default cipher is aes256. In the pkcs8 formats, the default hash is sha256 and default version is PBES2. In openssh format, the default number of rounds is 128. .. note:: The openssh format uses bcrypt for encryption, but unlike the traditional bcrypt cost factor used in password hashing which scales logarithmically, the encryption strength here scales linearly with the rounds value. Since the cipher is rekeyed 64 times per round, the default rounds value of 128 corresponds to 8192 total iterations, which is the equivalent of a bcrypt cost factor of 13. :param str format_name: (optional) The format to export the key in. :param str passphrase: (optional) A passphrase to encrypt the private key with. :param str cipher_name: (optional) The cipher to use for private key encryption. :param str hash_name: (optional) The hash to use for private key encryption. :param int pbe_version: (optional) The PBE version to use for private key encryption. :param int rounds: (optional) The number of KDF rounds to apply to the passphrase. 
:returns: bytes representing the exported private key """ if format_name in ('pkcs1-der', 'pkcs1-pem'): data = der_encode(self.encode_pkcs1_private()) if passphrase is not None: if format_name == 'pkcs1-der': raise KeyExportError('PKCS#1 DER format does not support ' 'private key encryption') alg, iv, data = pkcs1_encrypt(data, cipher_name, passphrase) headers = (b'Proc-Type: 4,ENCRYPTED\n' + b'DEK-Info: ' + alg + b',' + binascii.b2a_hex(iv).upper() + b'\n\n') else: headers = b'' if format_name == 'pkcs1-pem': keytype = self.pem_name + b' PRIVATE KEY' data = (b'-----BEGIN ' + keytype + b'-----\n' + headers + _wrap_base64(data) + b'-----END ' + keytype + b'-----\n') return data elif format_name in ('pkcs8-der', 'pkcs8-pem'): alg_params, data = self.encode_pkcs8_private() data = der_encode((0, (self.pkcs8_oid, alg_params), data)) if passphrase is not None: data = pkcs8_encrypt(data, cipher_name, hash_name, pbe_version, passphrase) if format_name == 'pkcs8-pem': if passphrase is not None: keytype = b'ENCRYPTED PRIVATE KEY' else: keytype = b'PRIVATE KEY' data = (b'-----BEGIN ' + keytype + b'-----\n' + _wrap_base64(data) + b'-----END ' + keytype + b'-----\n') return data elif format_name == 'openssh': check = os.urandom(4) nkeys = 1 data = b''.join((check, check, self.get_ssh_private_key(), String(self.get_comment() or ''))) if passphrase is not None: try: alg = cipher_name.encode('ascii') key_size, iv_size, block_size, mode = \ get_encryption_params(alg) except (KeyError, UnicodeEncodeError): raise KeyEncryptionError('Unknown cipher: %s' % cipher_name) from None if not _bcrypt_available: # pragma: no cover raise KeyExportError('OpenSSH private key encryption ' 'requires bcrypt with KDF support') kdf = b'bcrypt' salt = os.urandom(_OPENSSH_SALT_LEN) kdf_data = b''.join((String(salt), UInt32(rounds))) if isinstance(passphrase, str): passphrase = passphrase.encode('utf-8') # pylint: disable=no-member key = bcrypt.kdf(passphrase, salt, key_size + iv_size, rounds) # pylint: enable=no-member cipher = get_cipher(alg, key[:key_size], key[key_size:]) block_size = max(block_size, 8) else: cipher = None alg = b'none' kdf = b'none' kdf_data = b'' block_size = 8 mac = b'' pad = len(data) % block_size if pad: # pragma: no branch data = data + bytes(range(1, block_size + 1 - pad)) if cipher: if mode == 'chacha': data, mac = cipher.encrypt_and_sign(b'', data, UInt64(0)) elif mode == 'gcm': data, mac = cipher.encrypt_and_sign(b'', data) else: data, mac = cipher.encrypt(data), b'' data = b''.join((_OPENSSH_KEY_V1, String(alg), String(kdf), String(kdf_data), UInt32(nkeys), String(self.get_ssh_public_key()), String(data), mac)) return (b'-----BEGIN OPENSSH PRIVATE KEY-----\n' + _wrap_base64(data, _OPENSSH_WRAP_LEN) + b'-----END OPENSSH PRIVATE KEY-----\n') else: raise KeyExportError('Unknown export format') def export_public_key(self, format_name='openssh'): """Export a public key in the requested format This method returns this object's public key encoded in the requested format. Available formats include: pkcs1-der, pkcs1-pem, pkcs8-der, pkcs8-pem, openssh, rfc4716 By default, openssh format will be used. :param str format_name: (optional) The format to export the key in. 
:returns: bytes representing the exported public key """ if format_name in ('pkcs1-der', 'pkcs1-pem'): data = der_encode(self.encode_pkcs1_public()) if format_name == 'pkcs1-pem': keytype = self.pem_name + b' PUBLIC KEY' data = (b'-----BEGIN ' + keytype + b'-----\n' + _wrap_base64(data) + b'-----END ' + keytype + b'-----\n') return data elif format_name in ('pkcs8-der', 'pkcs8-pem'): alg_params, data = self.encode_pkcs8_public() data = der_encode(((self.pkcs8_oid, alg_params), BitString(data))) if format_name == 'pkcs8-pem': data = (b'-----BEGIN PUBLIC KEY-----\n' + _wrap_base64(data) + b'-----END PUBLIC KEY-----\n') return data elif format_name == 'openssh': data = self.get_ssh_public_key() if self._comment: comment = b' ' + self._comment.encode('utf-8') else: comment = b'' return (self.algorithm + b' ' + binascii.b2a_base64(data)[:-1] + comment + b'\n') elif format_name == 'rfc4716': data = self.get_ssh_public_key() if self._comment: comment = (b'Comment: "' + self._comment.encode('utf-8') + b'"\n') else: comment = b'' return (b'---- BEGIN SSH2 PUBLIC KEY ----\n' + comment + _wrap_base64(data) + b'---- END SSH2 PUBLIC KEY ----\n') else: raise KeyExportError('Unknown export format') def write_private_key(self, filename, *args, **kwargs): """Write a private key to a file in the requested format This method is a simple wrapper around :meth:`export_private_key` which writes the exported key data to a file. :param str filename: The filename to write the private key to. :param \\*args,\\ \\*\\*kwargs: Additional arguments to pass through to :meth:`export_private_key`. """ with open(filename, 'wb') as f: f.write(self.export_private_key(*args, **kwargs)) def write_public_key(self, filename, *args, **kwargs): """Write a public key to a file in the requested format This method is a simple wrapper around :meth:`export_public_key` which writes the exported key data to a file. :param str filename: The filename to write the public key to. :param \\*args,\\ \\*\\*kwargs: Additional arguments to pass through to :meth:`export_public_key`. """ with open(filename, 'wb') as f: f.write(self.export_public_key(*args, **kwargs)) def append_private_key(self, filename, *args, **kwargs): """Append a private key to a file in the requested format This method is a simple wrapper around :meth:`export_private_key` which appends the exported key data to an existing file. :param str filename: The filename to append the private key to. :param \\*args,\\ \\*\\*kwargs: Additional arguments to pass through to :meth:`export_private_key`. """ with open(filename, 'ab') as f: f.write(self.export_private_key(*args, **kwargs)) def append_public_key(self, filename, *args, **kwargs): """Append a public key to a file in the requested format This method is a simple wrapper around :meth:`export_public_key` which appends the exported key data to an existing file. :param str filename: The filename to append the public key to. :param \\*args,\\ \\*\\*kwargs: Additional arguments to pass through to :meth:`export_public_key`. 
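        As a usage sketch (``ca_key`` and the file names are placeholders,
        and ``generate_private_key`` is provided elsewhere in this
        module)::

            user_key = generate_private_key('ssh-rsa')
            user_key.write_private_key('id_rsa', passphrase='secret')
            user_key.write_public_key('id_rsa.pub')
            user_key.append_public_key('authorized_keys')

            cert = ca_key.generate_user_certificate(user_key, 'alice@example')
            cert.write_certificate('id_rsa-cert.pub')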
""" with open(filename, 'ab') as f: f.write(self.export_public_key(*args, **kwargs)) class SSHCertificate: """Parent class which holds an SSH certificate""" is_x509 = False is_x509_chain = False def __init__(self, algorithm, sig_algorithms, host_key_algorithms, key, data, comment): self.algorithm = algorithm self.sig_algorithms = sig_algorithms self.host_key_algorithms = host_key_algorithms self.key = key self.data = data self.set_comment(comment) def __eq__(self, other): return isinstance(other, type(self)) and self.data == other.data def __hash__(self): return hash(self.data) def get_algorithm(self): """Return the algorithm associated with this certificate""" return self.algorithm.decode('ascii') def get_comment(self): """Return the comment associated with this certificate :returns: `str` or ``None`` """ return self._comment def set_comment(self, comment): """Set the comment associated with this certificate :param comment: The new comment to associate with this certificate :type comment: `str` or ``None`` """ if isinstance(comment, bytes): try: comment = comment.decode('utf-8') except UnicodeDecodeError: raise KeyImportError('Invalid characters in comment') from None self._comment = comment or None def export_certificate(self, format_name='openssh'): """Export a certificate in the requested format This function returns this certificate encoded in the requested format. Available formats include: der, pem, openssh, rfc4716 By default, OpenSSH format will be used. :param str format_name: (optional) The format to export the certificate in. :returns: bytes representing the exported certificate """ if self.is_x509: if format_name == 'rfc4716': raise KeyExportError('RFC4716 format is not supported for ' 'X.509 certificates') else: if format_name in ('der', 'pem'): raise KeyExportError('DER and PEM formats are not supported ' 'for OpenSSH certificates') if format_name == 'der': return self.data elif format_name == 'pem': return (b'-----BEGIN CERTIFICATE-----\n' + _wrap_base64(self.data) + b'-----END CERTIFICATE-----\n') elif format_name == 'openssh': if self._comment: comment = b' ' + self._comment.encode('utf-8') else: comment = b'' return (self.algorithm + b' ' + binascii.b2a_base64(self.data)[:-1] + comment + b'\n') elif format_name == 'rfc4716': if self._comment: comment = (b'Comment: "' + self._comment.encode('utf-8') + b'"\n') else: comment = b'' return (b'---- BEGIN SSH2 PUBLIC KEY ----\n' + comment + _wrap_base64(self.data) + b'---- END SSH2 PUBLIC KEY ----\n') else: raise KeyExportError('Unknown export format') def write_certificate(self, filename, *args, **kwargs): """Write a certificate to a file in the requested format This function is a simple wrapper around export_certificate which writes the exported certificate to a file. :param str filename: The filename to write the certificate to. :param \\*args,\\ \\*\\*kwargs: Additional arguments to pass through to :meth:`export_certificate`. """ with open(filename, 'wb') as f: f.write(self.export_certificate(*args, **kwargs)) def append_certificate(self, filename, *args, **kwargs): """Append a certificate to a file in the requested format This function is a simple wrapper around export_certificate which appends the exported certificate to an existing file. :param str filename: The filename to append the certificate to. :param \\*args,\\ \\*\\*kwargs: Additional arguments to pass through to :meth:`export_certificate`. 
""" with open(filename, 'ab') as f: f.write(self.export_certificate(*args, **kwargs)) class SSHOpenSSHCertificate(SSHCertificate): """Class which holds an OpenSSH certificate""" _user_option_encoders = [] _user_extension_encoders = [] _host_option_encoders = [] _host_extension_encoders = [] _user_option_decoders = {} _user_extension_decoders = {} _host_option_decoders = {} _host_extension_decoders = {} def __init__(self, algorithm, key, data, principals, options, signing_key, serial, cert_type, key_id, valid_after, valid_before, comment): super().__init__(algorithm, key.sig_algorithms, (algorithm,), key, data, comment) self.principals = principals self.options = options self.signing_key = signing_key self._serial = serial self._cert_type = cert_type self._key_id = key_id self._valid_after = valid_after self._valid_before = valid_before @classmethod def generate(cls, signing_key, algorithm, key, serial, cert_type, key_id, principals, valid_after, valid_before, options, comment): """Generate a new SSH certificate""" principals = list(principals) cert_principals = b''.join(String(p) for p in principals) if cert_type == CERT_TYPE_USER: cert_options = cls._encode_options(options, cls._user_option_encoders) cert_extensions = cls._encode_options(options, cls._user_extension_encoders) else: cert_options = cls._encode_options(options, cls._host_option_encoders) cert_extensions = cls._encode_options(options, cls._host_extension_encoders) key = key.convert_to_public() data = b''.join((String(algorithm), cls._encode(key, serial, cert_type, key_id, cert_principals, valid_after, valid_before, cert_options, cert_extensions), String(signing_key.get_ssh_public_key()))) data += String(signing_key.sign(data, signing_key.algorithm)) signing_key = signing_key.convert_to_public() return cls(algorithm, key, data, principals, options, signing_key, serial, cert_type, key_id, valid_after, valid_before, comment) @classmethod def construct(cls, packet, algorithm, key_handler, comment): """Construct an SSH certificate from packetized data""" key_params, serial, cert_type, key_id, \ principals, valid_after, valid_before, \ options, extensions = cls._decode(packet, key_handler) signing_key = decode_ssh_public_key(packet.get_string()) data = packet.get_consumed_payload() signature = packet.get_string() packet.check_end() if not signing_key.verify(data, signature): raise KeyImportError('Invalid certificate signature') key = key_handler.make_public(*key_params) data = packet.get_consumed_payload() try: key_id = key_id.decode('utf-8') except UnicodeDecodeError: raise KeyImportError('Invalid characters in key ID') packet = SSHPacket(principals) principals = [] while packet: try: principal = packet.get_string().decode('utf-8') except UnicodeDecodeError: raise KeyImportError('Invalid characters in principal name') principals.append(principal) if cert_type == CERT_TYPE_USER: options = cls._decode_options(options, cls._user_option_decoders, True) options.update(cls._decode_options(extensions, cls._user_extension_decoders, False)) elif cert_type == CERT_TYPE_HOST: options = cls._decode_options(options, cls._host_option_decoders, True) options.update(cls._decode_options(extensions, cls._host_extension_decoders, False)) else: raise KeyImportError('Unknown certificate type') return cls(algorithm, key, data, principals, options, signing_key, serial, cert_type, key_id, valid_after, valid_before, comment) @classmethod def _encode(cls, key, serial, cert_type, key_id, principals, valid_after, valid_before, options, extensions): """Encode 
an SSH certificate""" raise NotImplementedError @classmethod def _decode(cls, packet, key_handler): """Decode an SSH certificate""" raise NotImplementedError @staticmethod def _encode_options(options, encoders): """Encode options found in this certificate""" result = [] for name, encoder in encoders: value = options.get(name) if value: result.append(String(name) + String(encoder(value))) return b''.join(result) @staticmethod def _encode_bool(value): """Encode a boolean option value""" # pylint: disable=unused-argument return b'' @staticmethod def _encode_force_cmd(force_command): """Encode a force-command option""" return String(force_command) @staticmethod def _encode_source_addr(source_address): """Encode a source-address option""" return NameList(str(addr).encode('ascii') for addr in source_address) @staticmethod def _decode_bool(packet): """Decode a boolean option value""" # pylint: disable=unused-argument return True @staticmethod def _decode_force_cmd(packet): """Decode a force-command option""" try: return packet.get_string().decode('utf-8') except UnicodeDecodeError: raise KeyImportError('Invalid characters in command') from None @staticmethod def _decode_source_addr(packet): """Decode a source-address option""" try: return [ip_network(addr.decode('ascii')) for addr in packet.get_namelist()] except (UnicodeDecodeError, ValueError): raise KeyImportError('Invalid source address') from None @staticmethod def _decode_options(options, decoders, critical=True): """Decode options found in this certificate""" packet = SSHPacket(options) result = {} while packet: name = packet.get_string() decoder = decoders.get(name) if decoder: data_packet = SSHPacket(packet.get_string()) result[name.decode('ascii')] = decoder(data_packet) data_packet.check_end() elif critical: raise KeyImportError('Unrecognized critical option: %s' % name.decode('ascii', errors='replace')) return result def validate(self, cert_type, principal): """Validate an OpenSSH certificate""" if self._cert_type != cert_type: raise ValueError('Invalid certificate type') now = time.time() if now < self._valid_after: raise ValueError('Certificate not yet valid') if now >= self._valid_before: raise ValueError('Certificate expired') if principal and self.principals and principal not in self.principals: raise ValueError('Certificate principal mismatch') class SSHOpenSSHCertificateV01(SSHOpenSSHCertificate): """Encoder/decoder class for version 01 OpenSSH certificates""" # pylint: disable=bad-whitespace _user_option_encoders = ( ('force-command', SSHOpenSSHCertificate._encode_force_cmd), ('source-address', SSHOpenSSHCertificate._encode_source_addr) ) _user_extension_encoders = ( ('permit-X11-forwarding', SSHOpenSSHCertificate._encode_bool), ('permit-agent-forwarding', SSHOpenSSHCertificate._encode_bool), ('permit-port-forwarding', SSHOpenSSHCertificate._encode_bool), ('permit-pty', SSHOpenSSHCertificate._encode_bool), ('permit-user-rc', SSHOpenSSHCertificate._encode_bool) ) _user_option_decoders = { b'force-command': SSHOpenSSHCertificate._decode_force_cmd, b'source-address': SSHOpenSSHCertificate._decode_source_addr } _user_extension_decoders = { b'permit-X11-forwarding': SSHOpenSSHCertificate._decode_bool, b'permit-agent-forwarding': SSHOpenSSHCertificate._decode_bool, b'permit-port-forwarding': SSHOpenSSHCertificate._decode_bool, b'permit-pty': SSHOpenSSHCertificate._decode_bool, b'permit-user-rc': SSHOpenSSHCertificate._decode_bool } # pylint: enable=bad-whitespace @classmethod def _encode(cls, key, serial, cert_type, key_id, 
principals, valid_after, valid_before, options, extensions): """Encode a version 01 SSH certificate""" return b''.join((String(os.urandom(32)), key.encode_ssh_public(), UInt64(serial), UInt32(cert_type), String(key_id), String(principals), UInt64(valid_after), UInt64(valid_before), String(options), String(extensions), String(''))) @classmethod def _decode(cls, packet, key_handler): """Decode a version 01 SSH certificate""" _ = packet.get_string() # nonce key_params = key_handler.decode_ssh_public(packet) serial = packet.get_uint64() cert_type = packet.get_uint32() key_id = packet.get_string() principals = packet.get_string() valid_after = packet.get_uint64() valid_before = packet.get_uint64() options = packet.get_string() extensions = packet.get_string() _ = packet.get_string() # reserved return (key_params, serial, cert_type, key_id, principals, valid_after, valid_before, options, extensions) class SSHX509Certificate(SSHCertificate): """Encoder/decoder class for SSH X.509 certificates""" is_x509 = True def __init__(self, key, x509_cert): super().__init__(b'x509v3-' + key.algorithm, key.x509_algorithms, key.x509_algorithms, key, x509_cert.data, x509_cert.comment) self.subject = x509_cert.subject self.issuer = x509_cert.issuer self.issuer_hash = x509_cert.issuer_hash self.user_principals = x509_cert.user_principals self.x509_cert = x509_cert def _expand_trust_store(self, cert, trusted_cert_paths, trust_store): """Look up certificates by issuer hash to build a trust store""" issuer_hash = cert.issuer_hash for path in trusted_cert_paths: idx = 0 try: while True: cert_path = os.path.join(path, issuer_hash + '.' + str(idx)) idx += 1 c = read_certificate(cert_path) if c.subject != cert.issuer or c in trust_store: continue trust_store.add(c) self._expand_trust_store(c, trusted_cert_paths, trust_store) except (OSError, KeyImportError): pass @classmethod def generate(cls, signing_key, key, subject, issuer, serial, valid_after, valid_before, ca, ca_path_len, purposes, user_principals, host_principals, hash_alg, comment): """Generate a new X.509 certificate""" key = key.convert_to_public() x509_cert = generate_x509_certificate(signing_key, key, subject, issuer, serial, valid_after, valid_before, ca, ca_path_len, purposes, user_principals, host_principals, hash_alg, comment) return cls(key, x509_cert) @classmethod def construct(cls, data): """Construct an SSH X.509 certificate from DER data""" try: x509_cert = import_x509_certificate(data) key = import_public_key(x509_cert.key_data) except ValueError as exc: raise KeyImportError(str(exc)) from None return cls(key, x509_cert) def validate_chain(self, trust_chain, trusted_certs, trusted_cert_paths, purposes, user_principal=None, host_principal=None): """Validate an X.509 certificate chain""" trust_chain = set(c for c in trust_chain if c.subject != c.issuer) trust_store = trust_chain | set(c for c in trusted_certs) if trusted_cert_paths: self._expand_trust_store(self, trusted_cert_paths, trust_store) for c in trust_chain: self._expand_trust_store(c, trusted_cert_paths, trust_store) self.x509_cert.validate([c.x509_cert for c in trust_store], purposes, user_principal, host_principal) class SSHX509CertificateChain(SSHCertificate): """Encoder/decoder class for an SSH X.509 certificate chain""" is_x509_chain = True def __init__(self, algorithm, data, certs, ocsp_responses, comment): key = certs[0].key super().__init__(algorithm, key.x509_algorithms, key.x509_algorithms, key, data, comment) self.subject = certs[0].subject self.issuer = certs[-1].issuer 
self.user_principals = certs[0].user_principals self._certs = certs self._ocsp_responses = ocsp_responses @classmethod def construct(cls, packet, algorithm, key_handler, comment=None): """Construct an SSH X.509 certificate from packetized data""" # pylint: disable=unused-argument cert_count = packet.get_uint32() certs = [import_certificate(packet.get_string()) for _ in range(cert_count)] ocsp_resp_count = packet.get_uint32() ocsp_responses = [packet.get_string() for _ in range(ocsp_resp_count)] packet.check_end() data = packet.get_consumed_payload() if not certs: raise KeyImportError('No certificates present') return cls(algorithm, data, certs, ocsp_responses, comment) @classmethod def construct_from_certs(cls, certs): """Construct an SSH X.509 certificate chain from certificates""" cert = certs[0] algorithm = cert.algorithm data = (String(algorithm) + UInt32(len(certs)) + b''.join(String(c.data) for c in certs) + UInt32(0)) return cls(algorithm, data, certs, (), cert.get_comment()) def validate_chain(self, trusted_certs, trusted_cert_paths, revoked_certs, purposes, user_principal=None, host_principal=None): """Validate an X.509 certificate chain""" if revoked_certs: for cert in self._certs: if cert in revoked_certs: raise ValueError('Revoked X.509 certificate in ' 'certificate chain') self._certs[0].validate_chain(self._certs[1:], trusted_certs, trusted_cert_paths, purposes, user_principal, host_principal) class SSHKeyPair: """Parent class which represents an asymmetric key pair This is an abstract class which provides a method to sign data with a private key and members to access the corresponding algorithm and public key or certificate information needed to identify what key was used for signing. """ _key_type = 'unknown' def __init__(self, algorithm, comment): self.algorithm = algorithm self.set_comment(comment) def get_key_type(self): """Return what type of key pair this is This method returns 'local' for locally loaded keys, and 'agent' for keys managed by an SSH agent. """ return self._key_type def get_algorithm(self): """Return the algorithm associated with this key pair""" return self.algorithm.decode('ascii') def get_comment(self): """Return the comment associated with this key pair :returns: `str` or ``None`` """ return self._comment def set_comment(self, comment): """Set the comment associated with this key pair :param comment: The new comment to associate with this key pair :type comment: `str` or ``None`` """ if isinstance(comment, bytes): try: comment = comment.decode('utf-8') except UnicodeDecodeError: raise KeyImportError('Invalid characters in comment') from None self._comment = comment or None def set_sig_algorithm(self, sig_algorithm): """Set the signature algorithm to use when signing data""" raise NotImplementedError def sign(self, data): """Sign a block of data with this private key""" raise NotImplementedError class SSHLocalKeyPair(SSHKeyPair): """Class which holds a local asymmetric key pair This class holds a private key and associated public data which can either be the matching public key or a certificate which has signed that public key. 
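# --- Illustrative sketch (not part of the original module) -----------------
# A minimal, hedged example of how the SSHLocalKeyPair class defined below
# wraps a private SSHKey so it can be used for signing.  This assumes the key
# generation helper is reachable through the asyncssh package and that
# SSHLocalKeyPair is imported directly from asyncssh.public_key (it is not
# necessarily part of the public API); the payload bytes are a placeholder.
import asyncssh
from asyncssh.public_key import SSHLocalKeyPair

key = asyncssh.generate_private_key('ssh-rsa')
keypair = SSHLocalKeyPair(key)             # no certificate attached

print(keypair.get_key_type())              # 'local'
print(keypair.get_algorithm())             # 'ssh-rsa'
signature = keypair.sign(b'example data')  # SSH-encoded signature blob
print(len(signature))
# ---------------------------------------------------------------------------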
""" _key_type = 'local' def __init__(self, key, cert=None): super().__init__(cert.algorithm if cert else key.algorithm, key.get_comment()) self._key = key self._cert = cert self.sig_algorithm = key.algorithm if cert: if key.get_ssh_public_key() != cert.key.get_ssh_public_key(): raise ValueError('Certificate key mismatch') self.sig_algorithms = cert.sig_algorithms self.host_key_algorithms = cert.host_key_algorithms self.public_data = cert.data else: self.sig_algorithms = key.sig_algorithms self.host_key_algorithms = key.sig_algorithms self.public_data = key.get_ssh_public_key() def get_agent_private_key(self): """Return binary encoding of keypair for upload to SSH agent""" if self._cert: data = String(self.public_data) + \ self._key.encode_agent_cert_private() else: data = self._key.encode_ssh_private() return String(self.algorithm) + data def set_sig_algorithm(self, sig_algorithm): """Set the signature algorithm to use when signing data""" if sig_algorithm.startswith(b'x509v3-'): sig_algorithm = sig_algorithm[7:] self.sig_algorithm = sig_algorithm if not self._cert: self.algorithm = sig_algorithm elif self._cert.algorithm.startswith(b'x509v3-'): self.algorithm = b'x509v3-' + sig_algorithm def sign(self, data): """Sign a block of data with this private key""" return self._key.sign(data, self.sig_algorithm) def _decode_pkcs1_private(pem_name, key_data): """Decode a PKCS#1 format private key""" handler = _pem_map.get(pem_name) if handler is None: raise KeyImportError('Unknown PEM key type: %s' % pem_name.decode('ascii')) key_params = handler.decode_pkcs1_private(key_data) if key_params is None: raise KeyImportError('Invalid %s private key' % pem_name.decode('ascii')) return handler.make_private(*key_params) def _decode_pkcs1_public(pem_name, key_data): """Decode a PKCS#1 format public key""" handler = _pem_map.get(pem_name) if handler is None: raise KeyImportError('Unknown PEM key type: %s' % pem_name.decode('ascii')) key_params = handler.decode_pkcs1_public(key_data) if key_params is None: raise KeyImportError('Invalid %s public key' % pem_name.decode('ascii')) return handler.make_public(*key_params) def _decode_pkcs8_private(key_data): """Decode a PKCS#8 format private key""" if (isinstance(key_data, tuple) and len(key_data) >= 3 and key_data[0] in (0, 1) and isinstance(key_data[1], tuple) and len(key_data[1]) == 2 and isinstance(key_data[2], bytes)): alg, alg_params = key_data[1] handler = _pkcs8_oid_map.get(alg) if handler is None: raise KeyImportError('Unknown PKCS#8 algorithm') key_params = handler.decode_pkcs8_private(alg_params, key_data[2]) if key_params is None: raise KeyImportError('Invalid %s private key' % handler.pem_name.decode('ascii')) return handler.make_private(*key_params) else: raise KeyImportError('Invalid PKCS#8 private key') def _decode_pkcs8_public(key_data): """Decode a PKCS#8 format public key""" if (isinstance(key_data, tuple) and len(key_data) == 2 and isinstance(key_data[0], tuple) and len(key_data[0]) == 2 and isinstance(key_data[1], BitString) and key_data[1].unused == 0): alg, alg_params = key_data[0] handler = _pkcs8_oid_map.get(alg) if handler is None: raise KeyImportError('Unknown PKCS#8 algorithm') key_params = handler.decode_pkcs8_public(alg_params, key_data[1].value) if key_params is None: raise KeyImportError('Invalid %s public key' % handler.pem_name.decode('ascii')) return handler.make_public(*key_params) else: raise KeyImportError('Invalid PKCS#8 public key') def _decode_openssh_private(data, passphrase): """Decode an OpenSSH format private key""" 
try: if not data.startswith(_OPENSSH_KEY_V1): raise KeyImportError('Unrecognized OpenSSH private key type') data = data[len(_OPENSSH_KEY_V1):] packet = SSHPacket(data) cipher_name = packet.get_string() kdf = packet.get_string() kdf_data = packet.get_string() nkeys = packet.get_uint32() _ = packet.get_string() # public_key key_data = packet.get_string() mac = packet.get_remaining_payload() if nkeys != 1: raise KeyImportError('Invalid OpenSSH private key') if cipher_name != b'none': if passphrase is None: raise KeyImportError('Passphrase must be specified to import ' 'encrypted private keys') try: key_size, iv_size, block_size, mode = \ get_encryption_params(cipher_name) except KeyError: raise KeyEncryptionError('Unknown cipher: %s' % cipher_name.decode('ascii')) from None if kdf != b'bcrypt': raise KeyEncryptionError('Unknown kdf: %s' % kdf.decode('ascii')) if not _bcrypt_available: # pragma: no cover raise KeyEncryptionError('OpenSSH private key encryption ' 'requires bcrypt with KDF support') packet = SSHPacket(kdf_data) salt = packet.get_string() rounds = packet.get_uint32() packet.check_end() if isinstance(passphrase, str): passphrase = passphrase.encode('utf-8') try: # pylint: disable=no-member key = bcrypt.kdf(passphrase, salt, key_size + iv_size, rounds) # pylint: enable=no-member except ValueError: raise KeyEncryptionError('Invalid OpenSSH ' 'private key') from None cipher = get_cipher(cipher_name, key[:key_size], key[key_size:]) if mode == 'chacha': key_data = cipher.verify_and_decrypt(b'', key_data, UInt64(0), mac) mac = b'' elif mode == 'gcm': key_data = cipher.verify_and_decrypt(b'', key_data, mac) mac = b'' else: key_data = cipher.decrypt(key_data) if key_data is None: raise KeyEncryptionError('Incorrect passphrase') block_size = max(block_size, 8) else: block_size = 8 if mac: raise KeyImportError('Invalid OpenSSH private key') packet = SSHPacket(key_data) check1 = packet.get_uint32() check2 = packet.get_uint32() if check1 != check2: if cipher_name != b'none': raise KeyEncryptionError('Incorrect passphrase') from None else: raise KeyImportError('Invalid OpenSSH private key') alg = packet.get_string() handler = _public_key_alg_map.get(alg) if not handler: raise KeyImportError('Unknown OpenSSH private key algorithm') key_params = handler.decode_ssh_private(packet) comment = packet.get_string() pad = packet.get_remaining_payload() if len(pad) >= block_size or pad != bytes(range(1, len(pad) + 1)): raise KeyImportError('Invalid OpenSSH private key') key = handler.make_private(*key_params) key.set_comment(comment) return key except PacketDecodeError: raise KeyImportError('Invalid OpenSSH private key') def _decode_der_private(data, passphrase): """Decode a DER format private key""" try: # pylint: disable=unpacking-non-sequence key_data, end = der_decode(data, partial_ok=True) # pylint: enable=unpacking-non-sequence except ASN1DecodeError: raise KeyImportError('Invalid DER private key') from None # First, if there's a passphrase, try to decrypt PKCS#8 if passphrase is not None: try: key_data = pkcs8_decrypt(key_data, passphrase) except KeyEncryptionError: # Decryption failed - try decoding it as unencrypted pass # Then, try to decode PKCS#8 try: return _decode_pkcs8_private(key_data), end except KeyImportError: # PKCS#8 failed - try PKCS#1 instead pass # If that fails, try each of the possible PKCS#1 encodings for pem_name in _pem_map: try: return _decode_pkcs1_private(pem_name, key_data), end except KeyImportError: # Try the next PKCS#1 encoding pass raise KeyImportError('Invalid DER 
private key') def _decode_der_public(data): """Decode a DER format public key""" try: # pylint: disable=unpacking-non-sequence key_data, end = der_decode(data, partial_ok=True) # pylint: enable=unpacking-non-sequence except ASN1DecodeError: raise KeyImportError('Invalid DER public key') from None # First, try to decode PKCS#8 try: return _decode_pkcs8_public(key_data), end except KeyImportError: # PKCS#8 failed - try PKCS#1 instead pass # If that fails, try each of the possible PKCS#1 encodings for pem_name in _pem_map: try: return _decode_pkcs1_public(pem_name, key_data), end except KeyImportError: # Try the next PKCS#1 encoding pass raise KeyImportError('Invalid DER public key') def _decode_der_certificate(data): """Decode a DER format X.509 certificate""" return SSHX509Certificate.construct(data) def _decode_der_certificate_list(data): """Decode a DER format X.509 certificate list""" certs = [] while data: try: _, end = der_decode(data, partial_ok=True) except ASN1DecodeError: raise KeyImportError('Invalid DER certificate') from None certs.append(_decode_der_certificate(data[:end])) data = data[end:] return certs def _decode_pem(lines, keytype): """Decode a PEM format key""" start = None line = '' for i, line in enumerate(lines): line = line.strip() if (line.startswith(b'-----BEGIN ') and line.endswith(b' ' + keytype + b'-----')): start = i+1 break if not start: raise KeyImportError('Missing PEM header of type %s' % keytype.decode('ascii')) pem_name = line[11:-(6+len(keytype))].strip() if pem_name: keytype = pem_name + b' ' + keytype headers = {} for start, line in enumerate(lines[start:], start): line = line.strip() if b':' in line: hdr, value = line.split(b':') headers[hdr.strip()] = value.strip() else: break end = None tail = b'-----END ' + keytype + b'-----' for i, line in enumerate(lines[start:], start): line = line.strip() if line == tail: end = i break if not end: raise KeyImportError('Missing PEM footer') try: data = binascii.a2b_base64(b''.join(lines[start:end])) except binascii.Error: raise KeyImportError('Invalid PEM data') from None return pem_name, headers, data, end+1 def _decode_pem_private(lines, passphrase): """Decode a PEM format private key""" pem_name, headers, data, end = _decode_pem(lines, b'PRIVATE KEY') if pem_name == b'OPENSSH': return _decode_openssh_private(data, passphrase), end if headers.get(b'Proc-Type') == b'4,ENCRYPTED': if passphrase is None: raise KeyImportError('Passphrase must be specified to import ' 'encrypted private keys') dek_info = headers.get(b'DEK-Info', b'').split(b',') if len(dek_info) != 2: raise KeyImportError('Invalid PEM encryption params') alg, iv = dek_info try: iv = binascii.a2b_hex(iv) except binascii.Error: raise KeyImportError('Invalid PEM encryption params') from None try: data = pkcs1_decrypt(data, alg, iv, passphrase) except KeyEncryptionError: raise KeyImportError('Unable to decrypt PKCS#1 ' 'private key') from None try: key_data = der_decode(data) except ASN1DecodeError: raise KeyImportError('Invalid PEM private key') from None if pem_name == b'ENCRYPTED': if passphrase is None: raise KeyImportError('Passphrase must be specified to import ' 'encrypted private keys') pem_name = b'' try: key_data = pkcs8_decrypt(key_data, passphrase) except KeyEncryptionError: raise KeyImportError('Unable to decrypt PKCS#8 ' 'private key') from None if pem_name: return _decode_pkcs1_private(pem_name, key_data), end else: return _decode_pkcs8_private(key_data), end def _decode_pem_public(lines): """Decode a PEM format public key""" pem_name, _, 
data, end = _decode_pem(lines, b'PUBLIC KEY') try: key_data = der_decode(data) except ASN1DecodeError: raise KeyImportError('Invalid PEM public key') from None if pem_name: return _decode_pkcs1_public(pem_name, key_data), end else: return _decode_pkcs8_public(key_data), end def _decode_pem_certificate(lines): """Decode a PEM format X.509 certificate""" pem_name, _, data, end = _decode_pem(lines, b'CERTIFICATE') if pem_name: raise KeyImportError('Invalid PEM certificate') return SSHX509Certificate.construct(data), end def _decode_pem_certificate_list(lines): """Decode a PEM format X.509 certificate list""" certs = [] while lines: cert, end = _decode_pem_certificate(lines) certs.append(cert) lines = lines[end:] return certs def _decode_openssh(line): """Decode an OpenSSH format public key or certificate""" line = line.split(None, 2) if len(line) < 2: raise KeyImportError('Invalid OpenSSH public key or certificate') elif len(line) == 2: comment = None else: comment = line[2] try: return line[0], binascii.a2b_base64(line[1]), comment except binascii.Error: raise KeyImportError('Invalid OpenSSH public key ' 'or certificate') from None def _decode_rfc4716(lines): """Decode an RFC 4716 format public key""" start = None for i, line in enumerate(lines): line = line.strip() if line == b'---- BEGIN SSH2 PUBLIC KEY ----': start = i+1 break if not start: raise KeyImportError('Missing RFC 4716 header') hdr = b'' comment = None for start, line in enumerate(lines[start:], start): line = line.strip() if line[-1:] == b'\\': hdr += line[:-1] else: hdr += line if b':' in hdr: hdr, value = hdr.split(b':') if hdr.strip() == b'Comment': comment = value.strip() if comment[:1] == b'"' and comment[-1:] == b'"': comment = comment[1:-1] hdr = b'' else: break end = None for i, line in enumerate(lines[start:], start): line = line.strip() if line == b'---- END SSH2 PUBLIC KEY ----': end = i break if not end: raise KeyImportError('Missing RFC 4716 footer') try: return binascii.a2b_base64(b''.join(lines[start:end])), comment, end+1 except binascii.Error: raise KeyImportError('Invalid RFC 4716 public key ' 'or certificate') from None def register_public_key_alg(algorithm, handler, sig_algorithms=None): """Register a new public key algorithm""" if not sig_algorithms: sig_algorithms = handler.sig_algorithms _public_key_alg_map[algorithm] = handler _public_key_algs.extend(sig_algorithms) if handler.pem_name: _pem_map[handler.pem_name] = handler if handler.pkcs8_oid: _pkcs8_oid_map[handler.pkcs8_oid] = handler def register_certificate_alg(version, algorithm, cert_algorithm, key_handler, cert_handler): """Register a new certificate algorithm""" _certificate_alg_map[cert_algorithm] = (key_handler, cert_handler) _certificate_algs.append(cert_algorithm) _certificate_version_map[algorithm, version] = \ (cert_algorithm, cert_handler) def register_x509_certificate_alg(cert_algorithm): """Register a new X.509 certificate algorithm""" if _x509_available: # pragma: no branch _certificate_alg_map[cert_algorithm] = (None, SSHX509CertificateChain) _x509_certificate_algs.append(cert_algorithm) def get_public_key_algs(): """Return supported public key algorithms""" return _public_key_algs def get_certificate_algs(): """Return supported certificate-based public key algorithms""" return _certificate_algs def get_x509_certificate_algs(): """Return supported X.509 certificate-based public key algorithms""" return _x509_certificate_algs def decode_ssh_public_key(data): """Decode a packetized SSH public key""" try: packet = SSHPacket(data) alg = 
packet.get_string() handler = _public_key_alg_map.get(alg) if handler: key_params = handler.decode_ssh_public(packet) packet.check_end() key = handler.make_public(*key_params) key.algorithm = alg return key else: raise KeyImportError('Unknown key algorithm: %s' % alg.decode('ascii', errors='replace')) except PacketDecodeError: raise KeyImportError('Invalid public key') from None def decode_ssh_certificate(data, comment=None): """Decode a packetized SSH certificate""" try: packet = SSHPacket(data) alg = packet.get_string() key_handler, cert_handler = _certificate_alg_map.get(alg, (None, None)) if cert_handler: return cert_handler.construct(packet, alg, key_handler, comment) else: raise KeyImportError('Unknown certificate algorithm: %s' % alg.decode('ascii', errors='replace')) except (PacketDecodeError, ValueError): raise KeyImportError('Invalid OpenSSH certificate') from None def generate_private_key(alg_name, comment=None, **kwargs): """Generate a new private key This function generates a new private key of a type matching the requested SSH algorithm. Depending on the algorithm, additional parameters can be passed which affect the generated key. Available algorithms include: ssh-dss, ssh-rsa, ecdsa-sha2-nistp256, ecdsa-sha2-nistp384, ecdsa-sha2-nistp521, ssh-ed25519 For ssh-dss, no parameters are supported. The key size is fixed at 1024 bits due to the use of SHA1 signatures. For ssh-rsa, the key size can be specified using the ``key_size`` parameter, and the RSA public exponent can be changed using the ``exponent`` parameter. By default, generated keys are 2048 bits with a public exponent of 65537. For ecdsa, the curve to use is part of the SSH algorithm name and that determines the key size. No other parameters are supported. For ssh-ed25519, no parameters are supported. The key size is fixed by the algorithm at 256 bits. :param str alg_name: The SSH algorithm name corresponding to the desired type of key. :param comment: (optional) A comment to associate with this key. :param int key_size: (optional) The key size in bits for RSA keys. :param int exponent: (optional) The public exponent for RSA keys. :type comment: `str` or ``None`` :returns: An :class:`SSHKey` private key :raises: :exc:`KeyGenerationError` if the requested key parameters are unsupported """ algorithm = alg_name.encode('utf-8') handler = _public_key_alg_map.get(algorithm) if handler: try: key = handler.generate(algorithm, **kwargs) except (TypeError, ValueError) as exc: raise KeyGenerationError(str(exc)) from None else: raise KeyGenerationError('Unknown algorithm: %s' % alg_name) key.set_comment(comment) return key def import_private_key(data, passphrase=None): """Import a private key This function imports a private key encoded in PKCS#1 or PKCS#8 DER or PEM format or OpenSSH format. Encrypted private keys can be imported by specifying the passphrase needed to decrypt them. :param data: The data to import. :param str passphrase: (optional) The passphrase to use to decrypt the key. 
:type data: bytes or ASCII string :returns: An :class:`SSHKey` private key """ if isinstance(data, str): try: data = data.encode('ascii') except UnicodeEncodeError: raise KeyImportError('Invalid encoding for private key') from None stripped_key = data.lstrip() if stripped_key.startswith(b'-----'): key, _ = _decode_pem_private(stripped_key.splitlines(), passphrase) else: key, _ = _decode_der_private(data, passphrase) return key def import_private_key_and_certs(data, passphrase=None): """Import a private key and optional certificate chain""" stripped_key = data.lstrip() if stripped_key.startswith(b'-----'): lines = stripped_key.splitlines() key, end = _decode_pem_private(lines, passphrase) lines = lines[end:] certs = _decode_pem_certificate_list(lines) if any(lines) else None else: key, end = _decode_der_private(data, passphrase) data = data[end:] certs = _decode_der_certificate_list(data) if data else None if certs: chain = SSHX509CertificateChain.construct_from_certs(certs) else: chain = None return key, chain def import_public_key(data): """Import a public key This function imports a public key encoded in OpenSSH, RFC4716, or PKCS#1 or PKCS#8 DER or PEM format. :param data: The data to import. :type data: bytes or ASCII string :returns: An :class:`SSHKey` public key """ if isinstance(data, str): try: data = data.encode('ascii') except UnicodeEncodeError: raise KeyImportError('Invalid encoding for public key') from None stripped_key = data.lstrip() if stripped_key.startswith(b'-----'): key, _ = _decode_pem_public(stripped_key.splitlines()) elif stripped_key.startswith(b'---- '): data, comment, _ = _decode_rfc4716(stripped_key.splitlines()) key = decode_ssh_public_key(data) key.set_comment(comment) elif data.startswith(b'\x30'): key, _ = _decode_der_public(data) elif data: algorithm, data, comment = _decode_openssh(stripped_key.splitlines()[0]) key = decode_ssh_public_key(data) if algorithm != key.algorithm: raise KeyImportError('Public key algorithm mismatch') key.set_comment(comment) else: raise KeyImportError('Invalid public key') return key def import_certificate(data): """Import a certificate This function imports an SSH certificate in DER, PEM, OpenSSH, or RFC4716 format. :param data: The data to import. 
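# --- Illustrative sketch (not part of the original module) -----------------
# Example use of the key generation and import helpers documented above,
# assuming they are accessed through the top-level asyncssh package.  The
# export format names ('pkcs8-pem', 'openssh') are assumed from the export
# methods defined earlier in this module; the passphrase is a placeholder.
import asyncssh

# Generate a 3072-bit RSA key with a comment
key = asyncssh.generate_private_key('ssh-rsa', comment='example key',
                                    key_size=3072)

# Round-trip the key through exported encodings
private_pem = key.export_private_key('pkcs8-pem', 'secret')
public_openssh = key.export_public_key('openssh')

key2 = asyncssh.import_private_key(private_pem, passphrase='secret')
pubkey = asyncssh.import_public_key(public_openssh)

assert key2.get_ssh_public_key() == pubkey.get_ssh_public_key()
# ---------------------------------------------------------------------------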
:type data: bytes or ASCII string :returns: An :class:`SSHCertificate` object """ if isinstance(data, str): try: data = data.encode('ascii') except UnicodeEncodeError: raise KeyImportError('Invalid encoding for certificate') from None stripped_key = data.lstrip() if stripped_key.startswith(b'-----'): cert, _ = _decode_pem_certificate(stripped_key.splitlines()) elif data.startswith(b'\x30'): cert = _decode_der_certificate(data) elif stripped_key.startswith(b'---- '): data, comment, _ = _decode_rfc4716(stripped_key.splitlines()) cert = decode_ssh_certificate(data, comment) else: algorithm, data, comment = _decode_openssh(stripped_key.splitlines()[0]) if algorithm.startswith(b'x509v3-'): cert = _decode_der_certificate(data) else: cert = decode_ssh_certificate(data, comment) return cert def import_certificate_subject(data): """Import an X.509 certificate subject name""" try: algorithm, data = data.strip().split(None, 1) except ValueError: raise KeyImportError('Missing certificate subject algorithm') from None if algorithm.startswith('x509v3-'): match = _subject_pattern.match(data) if match: return data[match.end():] raise KeyImportError('Invalid certificate subject') def read_private_key(filename, passphrase=None): """Read a private key from a file This function reads a private key from a file. See the function :func:`import_private_key` for information about the formats supported. :param str filename: The file to read the key from. :param str passphrase: (optional) The passphrase to use to decrypt the key. :returns: An :class:`SSHKey` private key """ with open(filename, 'rb') as f: key = import_private_key(f.read(), passphrase) if not key.get_comment(): key.set_comment(filename) return key def read_private_key_and_certs(filename, passphrase=None): """Read a private key and optional certificate chain from a file""" with open(filename, 'rb') as f: key, cert = import_private_key_and_certs(f.read(), passphrase) if not key.get_comment(): key.set_comment(filename) return key, cert def read_public_key(filename): """Read a public key from a file This function reads a public key from a file. See the function :func:`import_public_key` for information about the formats supported. :param str filename: The file to read the key from. :returns: An :class:`SSHKey` public key """ with open(filename, 'rb') as f: key = import_public_key(f.read()) if not key.get_comment(): key.set_comment(filename) return key def read_certificate(filename): """Read a certificate from a file This function reads an SSH certificate from a file. See the function :func:`import_certificate` for information about the formats supported. :param str filename: The file to read the certificate from. :returns: An :class:`SSHCertificate` object """ with open(filename, 'rb') as f: return import_certificate(f.read()) def read_private_key_list(filename, passphrase=None): """Read a list of private keys from a file This function reads a list of private keys from a file. See the function :func:`import_private_key` for information about the formats supported. If any of the keys are encrypted, they must all be encrypted with the same passphrase. :param str filename: The file to read the keys from. :param str passphrase: (optional) The passphrase to use to decrypt the keys. 
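# --- Illustrative sketch (not part of the original module) -----------------
# Example use of the file-based helpers documented above (read_private_key,
# read_public_key, read_certificate), assuming access through the asyncssh
# package.  The file names and passphrase are placeholders and the files must
# already exist.
import asyncssh

user_key = asyncssh.read_private_key('id_rsa', 'secret')
user_pub = asyncssh.read_public_key('id_rsa.pub')
user_cert = asyncssh.read_certificate('id_rsa-cert.pub')

# Keys read from files pick up the file name as their comment if none is set
print(user_pub.get_comment())
# ---------------------------------------------------------------------------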
:returns: A list of :class:`SSHKey` private keys """ with open(filename, 'rb') as f: data = f.read() keys = [] stripped_key = data.strip() if stripped_key.startswith(b'-----'): lines = stripped_key.splitlines() while lines: key, end = _decode_pem_private(lines, passphrase) keys.append(key) lines = lines[end:] else: while data: key, end = _decode_der_private(data, passphrase) keys.append(key) data = data[end:] for key in keys: if not key.get_comment(): key.set_comment(filename) return keys def read_public_key_list(filename): """Read a list of public keys from a file This function reads a list of public keys from a file. See the function :func:`import_public_key` for information about the formats supported. :param str filename: The file to read the keys from. :returns: A list of :class:`SSHKey` public keys """ with open(filename, 'rb') as f: data = f.read() keys = [] stripped_key = data.strip() if stripped_key.startswith(b'-----'): lines = stripped_key.splitlines() while lines: key, end = _decode_pem_public(lines) keys.append(key) lines = lines[end:] elif stripped_key.startswith(b'---- '): lines = stripped_key.splitlines() while lines: data, comment, end = _decode_rfc4716(lines) key = decode_ssh_public_key(data) key.set_comment(comment) keys.append(key) lines = lines[end:] elif data.startswith(b'\x30'): while data: key, end = _decode_der_public(data) keys.append(key) data = data[end:] else: for line in stripped_key.splitlines(): algorithm, data, comment = _decode_openssh(line) key = decode_ssh_public_key(data) if algorithm != key.algorithm: raise KeyImportError('Public key algorithm mismatch') key.set_comment(comment) keys.append(key) for key in keys: if not key.get_comment(): key.set_comment(filename) return keys def read_certificate_list(filename): """Read a list of certificates from a file This function reads a list of SSH certificates from a file. See the function :func:`import_certificate` for information about the formats supported. :param str filename: The file to read the certificates from. :returns: A list of :class:`SSHCertificate` certificates """ with open(filename, 'rb') as f: data = f.read() certs = [] stripped_key = data.strip() if stripped_key.startswith(b'-----'): certs = _decode_pem_certificate_list(stripped_key.splitlines()) elif data.startswith(b'\x30'): certs = _decode_der_certificate_list(data) elif stripped_key.startswith(b'---- '): lines = stripped_key.splitlines() while lines: data, comment, end = _decode_rfc4716(lines) certs.append(decode_ssh_certificate(data, comment)) lines = lines[end:] else: for line in stripped_key.splitlines(): algorithm, data, comment = _decode_openssh(line) if algorithm.startswith(b'x509v3-'): cert = _decode_der_certificate(data) else: cert = decode_ssh_certificate(data, comment) certs.append(cert) return certs def load_keypairs(keylist, passphrase=None): """Load SSH private keys and optional matching certificates This function loads a list of SSH keys and optional matching certificates. When certificates are specified, the private key is added to the list both with and without the certificate. :param keylist: The list of private keys and certificates to load. :param str passphrase: (optional) The passphrase to use to decrypt private keys. 
:type keylist: *see* :ref:`SpecifyingPrivateKeys` :returns: A list of :class:`SSHKeyPair` objects """ result = [] if isinstance(keylist, str): try: keys = read_private_key_list(keylist, passphrase) if len(keys) > 1: return [SSHLocalKeyPair(key) for key in keys] except KeyImportError: pass keylist = [keylist] elif isinstance(keylist, (tuple, bytes, SSHKey, SSHKeyPair)): keylist = [keylist] elif not keylist: keylist = [] for key in keylist: if isinstance(key, SSHKeyPair): result.append(key) else: allow_certs = False default_cert_file = None ignore_missing_cert = False if isinstance(key, str): allow_certs = True default_cert_file = key + '-cert.pub' ignore_missing_cert = True elif isinstance(key, bytes): allow_certs = True elif isinstance(key, tuple): key, certs = key else: certs = None if isinstance(key, str): if allow_certs: key, certs = read_private_key_and_certs(key, passphrase) if not certs and default_cert_file: certs = default_cert_file else: key = read_private_key(key, passphrase) elif isinstance(key, bytes): if allow_certs: key, certs = import_private_key_and_certs(key, passphrase) else: key = import_private_key(key, passphrase) if certs: try: certs = load_certificates(certs) except OSError: if ignore_missing_cert: certs = None else: raise if certs is None: cert = None elif len(certs) == 1 and not certs[0].is_x509: cert = certs[0] else: cert = SSHX509CertificateChain.construct_from_certs(certs) if cert: result.append(SSHLocalKeyPair(key, cert)) result.append(SSHLocalKeyPair(key, None)) return result def load_public_keys(keylist): """Load public keys This function loads a list of SSH public keys. :param keylist: The list of public keys to load. :type keylist: *see* :ref:`SpecifyingPublicKeys` :returns: A list of :class:`SSHKey` objects """ if isinstance(keylist, str): return read_public_key_list(keylist) else: result = [] for key in keylist: if isinstance(key, str): key = read_public_key(key) elif isinstance(key, bytes): key = import_public_key(key) result.append(key) return result def load_certificates(certlist): """Load certificates This function loads a list of OpenSSH or X.509 certificates. :param certlist: The list of certificates to load. :type certlist: *see* :ref:`SpecifyingCertificates` :returns: A list of :class:`SSHCertificate` objects """ if isinstance(certlist, SSHCertificate): return [certlist] elif isinstance(certlist, (bytes, str)): certlist = [certlist] result = [] for cert in certlist: if isinstance(cert, str): certs = read_certificate_list(cert) elif isinstance(cert, bytes): certs = [import_certificate(cert)] elif isinstance(cert, SSHCertificate): certs = [cert] else: certs = cert result.extend(certs) return result asyncssh-1.11.1/asyncssh/rsa.py000066400000000000000000000153771320320510200164210ustar00rootroot00000000000000# Copyright (c) 2013-2017 by Ron Frederick . # All rights reserved. 
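# --- Illustrative sketch (not part of the original module) -----------------
# Example use of the load_keypairs(), load_public_keys() and load_certificates()
# helpers defined at the end of public_key.py above, assuming access through
# the asyncssh package.  File names are placeholders; 'id_rsa-cert.pub' is the
# default certificate file load_keypairs() looks for next to 'id_rsa'.
import asyncssh

# A key given as a file name is paired with 'id_rsa-cert.pub' if it exists,
# and the key pair is returned both with and without the certificate.
keypairs = asyncssh.load_keypairs('id_rsa', passphrase='secret')

# An explicit (key, certificates) tuple can also be given
keypairs += asyncssh.load_keypairs([('host_key', 'host_key-cert.pub')])

trusted = asyncssh.load_public_keys(['trusted_key_1.pub', 'trusted_key_2.pub'])
certs = asyncssh.load_certificates('ca_certs.pem')
# ---------------------------------------------------------------------------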
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """RSA public key encryption handler""" from .asn1 import ASN1DecodeError, ObjectIdentifier, der_encode, der_decode from .crypto import RSAPrivateKey, RSAPublicKey from .misc import all_ints from .packet import MPInt from .public_key import SSHKey, SSHOpenSSHCertificateV01, KeyExportError from .public_key import register_public_key_alg, register_certificate_alg from .public_key import register_x509_certificate_alg # Short variable names are used here, matching names in the spec # pylint: disable=invalid-name class _RSAKey(SSHKey): """Handler for RSA public key encryption""" algorithm = b'ssh-rsa' pem_name = b'RSA' pkcs8_oid = ObjectIdentifier('1.2.840.113549.1.1.1') sig_algorithms = (b'rsa-sha2-256', b'rsa-sha2-512', b'ssh-rsa') x509_sig_algorithms = (b'rsa2048-sha256', b'ssh-rsa') x509_algorithms = tuple(b'x509v3-' + alg for alg in x509_sig_algorithms) all_sig_algorithms = set(x509_sig_algorithms + sig_algorithms) def __eq__(self, other): # This isn't protected access - both objects are _RSAKey instances # pylint: disable=protected-access return (isinstance(other, type(self)) and self._key.n == other._key.n and self._key.e == other._key.e and self._key.d == other._key.d) def __hash__(self): return hash((self._key.n, self._key.e, self._key.d, self._key.p, self._key.q)) @classmethod def generate(cls, algorithm, *, key_size=2048, exponent=65537): """Generate a new RSA private key""" # pylint: disable=unused-argument return cls(RSAPrivateKey.generate(key_size, exponent)) @classmethod def make_private(cls, *args): """Construct an RSA private key""" return cls(RSAPrivateKey.construct(*args)) @classmethod def make_public(cls, *args): """Construct an RSA public key""" return cls(RSAPublicKey.construct(*args)) @classmethod def decode_pkcs1_private(cls, key_data): """Decode a PKCS#1 format RSA private key""" if (isinstance(key_data, tuple) and all_ints(key_data) and len(key_data) >= 9): return key_data[1:9] else: return None @classmethod def decode_pkcs1_public(cls, key_data): """Decode a PKCS#1 format RSA public key""" if (isinstance(key_data, tuple) and all_ints(key_data) and len(key_data) == 2): return key_data else: return None @classmethod def decode_pkcs8_private(cls, alg_params, data): """Decode a PKCS#8 format RSA private key""" if alg_params is not None: return None try: key_data = der_decode(data) except ASN1DecodeError: return None return cls.decode_pkcs1_private(key_data) @classmethod def decode_pkcs8_public(cls, alg_params, data): """Decode a PKCS#8 format RSA public key""" if alg_params is not None: return None try: key_data = der_decode(data) except ASN1DecodeError: return None return cls.decode_pkcs1_public(key_data) @classmethod def decode_ssh_private(cls, packet): """Decode an SSH format RSA private key""" n = packet.get_mpint() e = packet.get_mpint() d = packet.get_mpint() iqmp = packet.get_mpint() p = packet.get_mpint() q = packet.get_mpint() return n, e, d, p, q, d % (p-1), d % (q-1), iqmp @classmethod def decode_ssh_public(cls, packet): """Decode an SSH format RSA public key""" e = packet.get_mpint() n = packet.get_mpint() return n, e def encode_pkcs1_private(self): """Encode a PKCS#1 format RSA private key""" if not self._key.d: raise KeyExportError('Key is not private') 
return (0, self._key.n, self._key.e, self._key.d, self._key.p, self._key.q, self._key.dmp1, self._key.dmq1, self._key.iqmp) def encode_pkcs1_public(self): """Encode a PKCS#1 format RSA public key""" return self._key.n, self._key.e def encode_pkcs8_private(self): """Encode a PKCS#8 format RSA private key""" return None, der_encode(self.encode_pkcs1_private()) def encode_pkcs8_public(self): """Encode a PKCS#8 format RSA public key""" return None, der_encode(self.encode_pkcs1_public()) def encode_ssh_private(self): """Encode an SSH format RSA private key""" if not self._key.d: raise KeyExportError('Key is not private') return b''.join((MPInt(self._key.n), MPInt(self._key.e), MPInt(self._key.d), MPInt(self._key.iqmp), MPInt(self._key.p), MPInt(self._key.q))) def encode_ssh_public(self): """Encode an SSH format RSA public key""" return b''.join((MPInt(self._key.e), MPInt(self._key.n))) def encode_agent_cert_private(self): """Encode RSA certificate private key data for agent""" if not self._key.d: raise KeyExportError('Key is not private') return b''.join((MPInt(self._key.d), MPInt(self._key.iqmp), MPInt(self._key.p), MPInt(self._key.q))) def sign_der(self, data, sig_algorithm): """Compute a DER-encoded signature of the specified data""" # pylint: disable=unused-argument if not self._key.d: raise ValueError('Private key needed for signing') return self._key.sign(data, sig_algorithm) def verify_der(self, data, sig_algorithm, sig): """Verify a DER-encoded signature of the specified data""" # pylint: disable=unused-argument return self._key.verify(data, sig, sig_algorithm) def sign_ssh(self, data, sig_algorithm): """Compute an SSH-encoded signature of the specified data""" return self.sign_der(data, sig_algorithm) def verify_ssh(self, data, sig_algorithm, sig): """Verify an SSH-encoded signature of the specified data""" return self.verify_der(data, sig_algorithm, sig) register_public_key_alg(b'ssh-rsa', _RSAKey) register_certificate_alg(1, b'ssh-rsa', b'ssh-rsa-cert-v01@openssh.com', _RSAKey, SSHOpenSSHCertificateV01) for alg in _RSAKey.x509_algorithms: register_x509_certificate_alg(alg) asyncssh-1.11.1/asyncssh/saslprep.py000066400000000000000000000061101320320510200174460ustar00rootroot00000000000000# Copyright (c) 2013-2015 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SASLprep implementation This module implements the stringprep algorithm defined in RFC 3454 and the SASLprep profile of stringprep defined in RFC 4013. This profile is used to normalize usernames and passwords sent in the SSH protocol. 
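# --- Illustrative sketch (not part of the original module) -----------------
# The saslprep() function defined just below normalizes usernames and
# passwords per RFC 4013.  This sketch assumes importing it directly from
# asyncssh.saslprep (it is an internal module, not the public API).
from asyncssh.saslprep import saslprep, SASLPrepError

print(saslprep('user\u00ADname'))   # soft hyphen (table B.1) maps to nothing -> 'username'
print(saslprep('I\u00A0X'))         # non-breaking space (C.1.2) maps to ASCII space -> 'I X'

try:
    saslprep('bad\u0007char')       # ASCII control character (C.2.1) is prohibited
except SASLPrepError as exc:
    print('rejected:', exc)
# ---------------------------------------------------------------------------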
""" # The stringprep module should not be flagged as deprecated # pylint: disable=deprecated-module import stringprep # pylint: enable=deprecated-module import unicodedata class SASLPrepError(ValueError): """Invalid data provided to saslprep""" def _check_bidi(s): """Enforce bidirectional character check from RFC 3454 (stringprep)""" r_and_al_cat = False l_cat = False for c in s: if not r_and_al_cat and stringprep.in_table_d1(c): r_and_al_cat = True if not l_cat and stringprep.in_table_d2(c): l_cat = True if r_and_al_cat and l_cat: raise SASLPrepError('Both RandALCat and LCat characters present') if r_and_al_cat and not (stringprep.in_table_d1(s[0]) and stringprep.in_table_d1(s[-1])): raise SASLPrepError('RandALCat character not at both start and end') def _stringprep(s, check_unassigned, mapping, normalization, prohibited, bidi): """Implement a stringprep profile as defined in RFC 3454""" if not isinstance(s, str): raise TypeError('argument 0 must be str, not %s' % type(s).__name__) if check_unassigned: # pragma: no branch for c in s: if stringprep.in_table_a1(c): raise SASLPrepError('Unassigned character: %r' % c) if mapping: # pragma: no branch s = mapping(s) if normalization: # pragma: no branch s = unicodedata.normalize(normalization, s) if prohibited: # pragma: no branch for c in s: for lookup in prohibited: if lookup(c): raise SASLPrepError('Prohibited character: %r' % c) if bidi: # pragma: no branch _check_bidi(s) return s def _map_saslprep(s): """Map stringprep table B.1 to nothing and C.1.2 to ASCII space""" r = [] for c in s: if stringprep.in_table_c12(c): r.append(' ') elif not stringprep.in_table_b1(c): r.append(c) return ''.join(r) def saslprep(s): """Implement SASLprep profile defined in RFC 4013""" prohibited = (stringprep.in_table_c12, stringprep.in_table_c21_c22, stringprep.in_table_c3, stringprep.in_table_c4, stringprep.in_table_c5, stringprep.in_table_c6, stringprep.in_table_c7, stringprep.in_table_c8, stringprep.in_table_c9) return _stringprep(s, True, _map_saslprep, 'NFKC', prohibited, True) asyncssh-1.11.1/asyncssh/scp.py000066400000000000000000000663361320320510200164220ustar00rootroot00000000000000# Copyright (c) 2017 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation # Jonathan Slenders - proposed changes to allow SFTP server callbacks # to be coroutines """SCP handlers""" import argparse import asyncio import posixpath import shlex import stat from .constants import DEFAULT_LANG from .constants import FX_BAD_MESSAGE, FX_CONNECTION_LOST, FX_FAILURE from .sftp import LocalFile, match_glob from .sftp import SFTP_BLOCK_SIZE, SFTPAttrs, SFTPError, SFTPServerFile def _parse_cd_args(args): """Parse arguments to an SCP copy or dir request""" try: permissions, size, name = args.split() return int(permissions, 8), int(size), name except ValueError: raise SCPError(FX_BAD_MESSAGE, 'Invalid copy or dir request') from None def _parse_t_args(args): """Parse argument to an SCP time request""" try: atime, _, mtime, _ = args.split() return int(atime), int(mtime) except ValueError: raise SCPError(FX_BAD_MESSAGE, 'Invalid time request') from None @asyncio.coroutine def _parse_path(path): """Convert an SCP path into an SSHClientConnection and path""" from . 
import connect if isinstance(path, tuple): conn, path = path elif isinstance(path, str) and ':' in path: conn, path = path.split(':') elif isinstance(path, bytes) and b':' in path: conn, path = path.split(b':') elif isinstance(path, (str, bytes)): conn = None else: conn = path path = b'.' if isinstance(conn, (str, bytes)): close_conn = True conn = yield from connect(conn) elif isinstance(conn, tuple): close_conn = True conn = yield from connect(*conn) else: close_conn = False return conn, path, close_conn @asyncio.coroutine def _start_remote(conn, source, must_be_dir, preserve, recurse, path): """Start remote SCP server""" if isinstance(path, str): path = path.encode('utf-8') command = (b'scp ' + (b'-f ' if source else b'-t ') + (b'-d ' if must_be_dir else b'') + (b'-p ' if preserve else b'') + (b'-r ' if recurse else b'') + path) writer, reader, _ = yield from conn.open_session(command, encoding=None) return reader, writer class SCPError(SFTPError): """SCP error""" def __init__(self, code, reason, path=None, fatal=False, suppress_send=False, lang=DEFAULT_LANG): if isinstance(reason, bytes): reason = reason.decode('utf-8', errors='replace') if isinstance(path, bytes): path = path.decode('utf-8', errors='replace') if path: reason = reason + ': ' + path super().__init__(code, reason, lang) self.fatal = fatal self.suppress_send = suppress_send class _SCPArgParser(argparse.ArgumentParser): """A parser for SCP arguments""" def __init__(self): super().__init__(add_help=False) group = self.add_mutually_exclusive_group(required=True) group.add_argument('-f', dest='source', action='store_true') group.add_argument('-t', dest='source', action='store_false') self.add_argument('-d', dest='must_be_dir', action='store_true') self.add_argument('-p', dest='preserve', action='store_true') self.add_argument('-r', dest='recurse', action='store_true') self.add_argument('-v', dest='verbose', action='store_true') self.add_argument('path') def error(self, message): raise ValueError(message) def parse(self, command): """Parse an SCP command""" return self.parse_args(shlex.split(command)[1:]) class _SCPHandler: """SCP handler""" def __init__(self, reader, writer, error_handler=None): self._reader = reader self._writer = writer self._error_handler = error_handler @asyncio.coroutine def await_response(self): """Wait for an SCP response""" result = yield from self._reader.read(1) if result != b'\0': reason = yield from self._reader.readline() if not result or not reason.endswith(b'\n'): raise SCPError(FX_CONNECTION_LOST, 'Connection lost', fatal=True, suppress_send=True) if result not in b'\x01\x02': reason = result + reason return SCPError(FX_FAILURE, reason[:-1], fatal=result != b'\x01', suppress_send=True) return None def send_request(self, *args): """Send an SCP request""" self._writer.write(b''.join(args) + b'\n') @asyncio.coroutine def make_request(self, *args): """Send an SCP request and wait for a response""" self.send_request(*args) exc = yield from self.await_response() if exc: raise exc @asyncio.coroutine def send_data(self, data): """Send SCP file data""" self._writer.write(data) yield from self._writer.drain() def send_ok(self): """Send an SCP OK response""" self._writer.write(b'\0') def send_error(self, exc): """Send an SCP error response""" if isinstance(exc, SFTPError): reason = exc.reason.encode('utf-8') elif isinstance(exc, OSError): # pragma: no branch (win32) reason = exc.strerror.encode('utf-8') if exc.filename: if isinstance(exc.filename, str): # pragma: no cover (win32) exc.filename = 
exc.filename.encode('utf-8') reason += b': ' + exc.filename else: # pragma: no cover (win32) reason = str(exc).encode('utf-8') fatal = getattr(exc, 'fatal', False) self._writer.write((b'\x02' if fatal else b'\x01') + b'scp: ' + reason + b'\n') @asyncio.coroutine def recv_request(self): """Receive SCP request""" request = yield from self._reader.readline() if not request: return None, None return request[:1], request[1:-1] @asyncio.coroutine def recv_data(self, n): """Receive SCP file data""" return (yield from self._reader.read(n)) def handle_error(self, exc): """Handle an SCP error""" if isinstance(exc, BrokenPipeError): exc = SCPError(FX_CONNECTION_LOST, 'Connection lost', fatal=True, suppress_send=True) if not getattr(exc, 'suppress_send', False): self.send_error(exc) if getattr(exc, 'fatal', False) or self._error_handler is None: raise exc from None elif self._error_handler: self._error_handler(exc) def close(self): """Close an SCP session""" self._writer.close() class _SCPSource(_SCPHandler): """SCP handler for sending files""" def __init__(self, fs, reader, writer, preserve, recurse, block_size=SFTP_BLOCK_SIZE, progress_handler=None, error_handler=None): super().__init__(reader, writer, error_handler) self._fs = fs self._preserve = preserve self._recurse = recurse self._block_size = block_size self._progress_handler = progress_handler @asyncio.coroutine def _make_cd_request(self, action, attrs, size, path): """Make an SCP copy or dir request""" args = '%04o %d ' % (attrs.permissions & 0o7777, size) yield from self.make_request(action, args.encode('ascii'), posixpath.basename(path)) @asyncio.coroutine def _make_t_request(self, attrs): """Make an SCP time request""" args = '%d 0 %d 0' % (attrs.atime, attrs.mtime) yield from self.make_request(b'T', args.encode('ascii')) @asyncio.coroutine def _send_file(self, srcpath, dstpath, attrs): """Send a file over SCP""" file_obj = yield from self._fs.open(srcpath, 'rb') size = attrs.size local_exc = None offset = 0 try: yield from self._make_cd_request(b'C', attrs, size, srcpath) while offset < size: blocklen = min(size - offset, self._block_size) if local_exc: data = blocklen * b'\0' else: try: data = yield from file_obj.read(blocklen, offset) if not data: raise SCPError(FX_FAILURE, 'Unexpected EOF') except (OSError, SFTPError) as exc: local_exc = exc yield from self.send_data(data) offset += len(data) if self._progress_handler: self._progress_handler(srcpath, dstpath, offset, size) finally: yield from file_obj.close() if local_exc: self.send_error(local_exc) local_exc.suppress_send = True else: self.send_ok() remote_exc = yield from self.await_response() exc = remote_exc or local_exc if exc: raise exc @asyncio.coroutine def _send_dir(self, srcpath, dstpath, attrs): """Send directory over SCP""" yield from self._make_cd_request(b'D', attrs, 0, srcpath) for name in (yield from self._fs.listdir(srcpath)): if name in (b'.', b'..'): continue yield from self._send_files(posixpath.join(srcpath, name), posixpath.join(dstpath, name)) yield from self.make_request(b'E') @asyncio.coroutine def _send_files(self, srcpath, dstpath): """Send files via SCP""" try: attrs = yield from self._fs.stat(srcpath) if self._preserve: yield from self._make_t_request(attrs) if self._recurse and stat.S_ISDIR(attrs.permissions): yield from self._send_dir(srcpath, dstpath, attrs) elif stat.S_ISREG(attrs.permissions): yield from self._send_file(srcpath, dstpath, attrs) else: raise SCPError(FX_FAILURE, 'Not a regular file', srcpath) except (OSError, SFTPError, ValueError) as 
exc: self.handle_error(exc) @asyncio.coroutine def run(self, srcpath): """Start SCP transfer""" try: if isinstance(srcpath, str): srcpath = srcpath.encode('utf-8') exc = yield from self.await_response() if exc: raise exc for path in (yield from match_glob(self._fs, srcpath)): yield from self._send_files(path, b'') except (OSError, SFTPError) as exc: self.handle_error(exc) finally: self.close() class _SCPSink(_SCPHandler): """SCP handler for receiving files""" def __init__(self, fs, reader, writer, must_be_dir, preserve, recurse, block_size=SFTP_BLOCK_SIZE, progress_handler=None, error_handler=None): super().__init__(reader, writer, error_handler) self._fs = fs self._must_be_dir = must_be_dir self._preserve = preserve self._recurse = recurse self._block_size = block_size self._progress_handler = progress_handler @asyncio.coroutine def _recv_file(self, srcpath, dstpath, size): """Receive a file via SCP""" file_obj = yield from self._fs.open(dstpath, 'wb') local_exc = None offset = 0 try: self.send_ok() while offset < size: blocklen = min(size - offset, self._block_size) data = yield from self.recv_data(blocklen) if not data: raise SCPError(FX_CONNECTION_LOST, 'Connection lost', fatal=True, suppress_send=True) if not local_exc: try: yield from file_obj.write(data, offset) except (OSError, SFTPError) as exc: local_exc = exc offset += len(data) if self._progress_handler: self._progress_handler(srcpath, dstpath, offset, size) finally: yield from file_obj.close() remote_exc = yield from self.await_response() if local_exc: self.send_error(local_exc) local_exc.suppress_send = True else: self.send_ok() exc = remote_exc or local_exc if exc: raise exc @asyncio.coroutine def _recv_dir(self, srcpath, dstpath): """Receive a directory over SCP""" if not self._recurse: raise SCPError(FX_BAD_MESSAGE, 'Directory received without recurse') if (yield from self._fs.exists(dstpath)): if not (yield from self._fs.isdir(dstpath)): raise SCPError(FX_FAILURE, 'Not a directory', dstpath) else: yield from self._fs.mkdir(dstpath) yield from self._recv_files(srcpath, dstpath) @asyncio.coroutine def _recv_files(self, srcpath, dstpath): """Receive files over SCP""" self.send_ok() attrs = SFTPAttrs() while True: action, args = yield from self.recv_request() if not action: break try: if action in b'\x01\x02': raise SCPError(FX_FAILURE, args, fatal=action != b'\x01', suppress_send=True) elif action == b'T': if self._preserve: attrs.atime, attrs.mtime = _parse_t_args(args) self.send_ok() elif action == b'E': self.send_ok() elif action in b'CD': try: attrs.permissions, size, name = _parse_cd_args(args) new_srcpath = posixpath.join(srcpath, name) if (yield from self._fs.isdir(dstpath)): new_dstpath = posixpath.join(dstpath, name) else: new_dstpath = dstpath if action == b'D': yield from self._recv_dir(new_srcpath, new_dstpath) else: yield from self._recv_file(new_srcpath, new_dstpath, size) if self._preserve: yield from self._fs.setstat(new_dstpath, attrs) finally: attrs = SFTPAttrs() else: raise SCPError(FX_BAD_MESSAGE, 'Unknown request') except (OSError, SFTPError) as exc: self.handle_error(exc) @asyncio.coroutine def run(self, dstpath): """Start SCP file receive""" try: if isinstance(dstpath, str): dstpath = dstpath.encode('utf-8') if self._must_be_dir and not (yield from self._fs.isdir(dstpath)): self.handle_error(SCPError(FX_FAILURE, 'Not a directory', dstpath)) else: yield from self._recv_files(b'', dstpath) except (OSError, SFTPError, ValueError) as exc: self.handle_error(exc) finally: self.close() class _SCPCopier: 
"""SCP handler for remote-to-remote copies""" def __init__(self, src_reader, src_writer, dst_reader, dst_writer, block_size=SFTP_BLOCK_SIZE, progress_handler=None, error_handler=None): self._source = _SCPHandler(src_reader, src_writer) self._sink = _SCPHandler(dst_reader, dst_writer) self._block_size = block_size self._progress_handler = progress_handler self._error_handler = error_handler def _handle_error(self, exc): """Handle an SCP error""" if isinstance(exc, BrokenPipeError): exc = SCPError(FX_CONNECTION_LOST, 'Connection lost', fatal=True) if self._error_handler and not getattr(exc, 'fatal', False): self._error_handler(exc) else: raise exc @asyncio.coroutine def _forward_response(self, src, dst): """Forward an SCP response between two remote SCP servers""" # pylint: disable=no-self-use try: exc = yield from src.await_response() if exc: dst.send_error(exc) return exc else: dst.send_ok() return None except OSError as exc: return exc @asyncio.coroutine def _copy_file(self, path, size): """Copy a file from one remote SCP server to another""" offset = 0 while offset < size: blocklen = min(size - offset, self._block_size) data = yield from self._source.recv_data(blocklen) if not data: raise SCPError(FX_CONNECTION_LOST, 'Connection lost', fatal=True) yield from self._sink.send_data(data) offset += len(data) if self._progress_handler: self._progress_handler(path, path, offset, size) source_exc = yield from self._forward_response(self._source, self._sink) sink_exc = yield from self._forward_response(self._sink, self._source) exc = sink_exc or source_exc if exc: self._handle_error(exc) @asyncio.coroutine def run(self): """Start SCP remote-to-remote transfer""" paths = [] try: exc = yield from self._forward_response(self._sink, self._source) if exc: self._handle_error(exc) while True: action, args = yield from self._source.recv_request() if not action: break self._sink.send_request(action, args) if action in b'\x01\x02': exc = SCPError(FX_FAILURE, args, fatal=action != b'\x01') self._handle_error(exc) continue exc = yield from self._forward_response(self._sink, self._source) if exc: self._handle_error(exc) continue if action in b'CD': _, size, name = _parse_cd_args(args) if action == b'C': path = b'/'.join(paths + [name]) yield from self._copy_file(path, size) else: paths.append(name) elif action == b'E': if paths: paths.pop() else: break elif action != b'T': raise SCPError(FX_BAD_MESSAGE, 'Unknown SCP action') except (OSError, SFTPError) as exc: self._handle_error(exc) finally: self._source.close() self._sink.close() @asyncio.coroutine def scp(srcpaths, dstpath=None, *, preserve=False, recurse=False, block_size=SFTP_BLOCK_SIZE, progress_handler=None, error_handler=None): """Copy files using SCP This function is a coroutine which copies one or more files or directories using the SCP protocol. Source and destination paths can be string or bytes values to reference local files or can be a tuple of the form ``(conn, path)`` where ``conn`` is an open :class:`SSHClientConnection` to reference files and directories on a remote system. For convenience, a host name or tuple of the form ``(host, port)`` can be provided in place of the :class:`SSHClientConnection` to request that a new SSH connection be opened to a host using default connect arguments. A string or bytes value of the form ``'host:path'`` may also be used in place of the ``(conn, path)`` tuple to make a new connection to the requested host on the default SSH port. 
Either a single source path or a sequence of source paths can be provided, and each path can contain '*' and '?' wildcard characters which can be used to match multiple source files or directories.

When copying a single file or directory, the destination path can be either the full path to copy data into or the path to an existing directory where the data should be placed. In the latter case, the base file name from the source path will be used as the destination name.

When copying multiple files, the destination path must refer to a directory. If it doesn't already exist, a directory will be created with that name.

If the destination path is an :class:`SSHClientConnection` without a path or the path provided is empty, files are copied into the default destination working directory.

If preserve is ``True``, the access and modification times and permissions of the original files and directories are set on the copied files. However, due to the timing of when this information is sent, the preserved access time will be what was set on the source file before the copy begins. So, the access time on the source file will no longer match the destination after the transfer completes.

If recurse is ``True`` and the source path points at a directory, the entire subtree under that directory is copied.

Symbolic links found on the source will have the contents of their target copied rather than creating a destination symbolic link. When using this option during a recursive copy, one needs to watch out for links that result in loops. SCP does not provide a mechanism for preserving links. If you need this, consider using SFTP instead.

The block_size value controls the size of read and write operations issued to copy the files. It defaults to 16 KB.

If progress_handler is specified, it will be called after each block of a file is successfully copied. The arguments passed to this handler will be the source path, the destination path, the number of bytes copied so far, and the total number of bytes in the file being copied. If multiple source paths are provided or recurse is set to ``True``, the progress_handler will be called consecutively on each file being copied.

If error_handler is specified and an error occurs during the copy, this handler will be called with the exception instead of it being raised. This is primarily intended to be used when multiple source paths are provided or when recurse is set to ``True``, to allow error information to be collected without aborting the copy of the remaining files. The error handler can raise an exception if it wants the copy to completely stop. Otherwise, after an error, the copy will continue starting with the next file.
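As a minimal sketch (the host name and file names here are only illustrative, and :func:`scp` is assumed to be imported from the top-level ``asyncssh`` package), a single remote file could be downloaded as follows::

        import asyncio, asyncssh

        @asyncio.coroutine
        def download():
            conn = yield from asyncssh.connect('host.example.com')

            try:
                yield from asyncssh.scp((conn, 'remote-file'), 'local-file')
            finally:
                conn.close()
                yield from conn.wait_closed()

        asyncio.get_event_loop().run_until_complete(download())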
:param srcpaths: The paths of the source files or directories to copy :param dstpath: (optional) The path of the destination file or directory to copy into :param bool preserve: (optional) Whether or not to preserve the original file attributes :param bool recurse: (optional) Whether or not to recursively copy directories :param int block_size: (optional) The block size to use for file reads and writes :param callable progress_handler: (optional) The function to call to report copy progress :param callable error_handler: (optional) The function to call when an error occurs :raises: | :exc:`OSError` if a local file I/O error occurs | :exc:`SFTPError` if the server returns an error | :exc:`ValueError` if both source and destination are local """ if (isinstance(srcpaths, (str, bytes)) or (isinstance(srcpaths, tuple) and len(srcpaths) == 2)): srcpaths = [srcpaths] must_be_dir = len(srcpaths) > 1 dstconn, dstpath, close_dst = yield from _parse_path(dstpath) try: for srcpath in srcpaths: srcconn, srcpath, close_src = yield from _parse_path(srcpath) try: if srcconn and dstconn: src_reader, src_writer = yield from _start_remote( srcconn, True, must_be_dir, preserve, recurse, srcpath) dst_reader, dst_writer = yield from _start_remote( dstconn, False, must_be_dir, preserve, recurse, dstpath) copier = _SCPCopier(src_reader, src_writer, dst_reader, dst_writer, block_size, progress_handler, error_handler) yield from copier.run() elif srcconn: reader, writer = yield from _start_remote( srcconn, True, must_be_dir, preserve, recurse, srcpath) sink = _SCPSink(LocalFile, reader, writer, must_be_dir, preserve, recurse, block_size, progress_handler, error_handler) yield from sink.run(dstpath) elif dstconn: reader, writer = yield from _start_remote( dstconn, False, must_be_dir, preserve, recurse, dstpath) source = _SCPSource(LocalFile, reader, writer, preserve, recurse, block_size, progress_handler, error_handler) yield from source.run(srcpath) else: raise ValueError('Local copy not supported') finally: if close_src: srcconn.close() yield from srcconn.wait_closed() finally: if close_dst: dstconn.close() yield from dstconn.wait_closed() @asyncio.coroutine def run_scp_server(sftp_server, command, stdin, stdout, stderr): """Return a handler for an SCP server session""" try: args = _SCPArgParser().parse(command) except ValueError as exc: stderr.write(b'scp: ' + str(exc).encode('utf-8') + b'\n') stderr.channel.exit(1) return fs = SFTPServerFile(sftp_server) if args.source: handler = _SCPSource(fs, stdin, stdout, args.preserve, args.recurse, error_handler=False) else: handler = _SCPSink(fs, stdin, stdout, args.must_be_dir, args.preserve, args.recurse, error_handler=False) try: yield from handler.run(args.path) finally: sftp_server.exit() asyncssh-1.11.1/asyncssh/server.py000066400000000000000000000731361320320510200171370ustar00rootroot00000000000000# Copyright (c) 2013-2017 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH server protocol handler""" class SSHServer: """SSH server protocol handler Applications should subclass this when implementing an SSH server. 
At a minimum, one or more of the authentication handlers will need to be overridden to perform authentication, or :meth:`begin_auth` should be overridden to return ``False`` to indicate that no authentication is required. In addition, one or more of the :meth:`session_requested`, :meth:`connection_requested`, :meth:`server_requested`, :meth:`unix_connection_requested`, or :meth:`unix_server_requested` methods will need to be overridden to handle requests to open sessions or direct connections or set up listeners for forwarded connections. .. note:: The authentication callbacks described here can be defined as coroutines. However, they may be cancelled if they are running when the SSH connection is closed by the client. If they attempt to catch the CancelledError exception to perform cleanup, they should make sure to re-raise it to allow AsyncSSH to finish its own cleanup. """ # pylint: disable=no-self-use,unused-argument def connection_made(self, connection): """Called when a connection is made This method is called when a new TCP connection is accepted. The connection parameter should be stored if needed for later use. """ pass # pragma: no cover def connection_lost(self, exc): """Called when a connection is lost or closed This method is called when a connection is closed. If the connection is shut down cleanly, *exc* will be ``None``. Otherwise, it will be an exception explaining the reason for the disconnect. """ pass # pragma: no cover def debug_msg_received(self, msg, lang, always_display): """A debug message was received on this connection This method is called when the other end of the connection sends a debug message. Applications should implement this method if they wish to process these debug messages. :param str msg: The debug message sent :param str lang: The language the message is in :param bool always_display: Whether or not to display the message """ pass # pragma: no cover def begin_auth(self, username): """Authentication has been requested by the client This method will be called when authentication is attempted for the specified user. Applications should use this method to prepare whatever state they need to complete the authentication, such as loading in the set of authorized keys for that user. If no authentication is required for this user, this method should return ``False`` to cause the authentication to immediately succeed. Otherwise, it should return ``True`` to indicate that authentication should proceed. :param str username: The name of the user being authenticated :returns: A bool indicating whether authentication is required """ return True # pragma: no cover def validate_gss_principal(self, username, user_principal, host_principal): """Return whether a GSS principal is valid for this user This method should return ``True`` if the specified user principal is valid for the user being authenticated. It can be overridden by applications wishing to perform their authentication. If blocking operations need to be performed to determine the validity of the principal, this method may be defined as a coroutine. By default, this method will return ``True`` only when the name in the user principal exactly matches the username and the domain of the user principal matches the domain of the host principal. 
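For example (the admin realm below is purely illustrative), an application could extend the default check to also accept principals from a separate administrative realm::

        def validate_gss_principal(self, username, user_principal,
                                   host_principal):
            if user_principal == username + '@ADMIN.EXAMPLE.COM':
                return True

            return super().validate_gss_principal(username, user_principal,
                                                  host_principal)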
:param str username: The user being authenticated :param str user_principal: The user principal sent by the client :param str host_principal: The host principal sent by the server :returns: A bool indicating if the specified user principal is valid for the user being authenticated """ host_domain = host_principal.rsplit('@')[-1] return user_principal == username + '@' + host_domain def public_key_auth_supported(self): """Return whether or not public key authentication is supported This method should return ``True`` if client public key authentication is supported. Applications wishing to support it must have this method return ``True`` and implement :meth:`validate_public_key` to return whether or not the key provided by the client is valid for the user being authenticated. By default, it returns ``False`` indicating the client public key authentication is not supported. :returns: A bool indicating if public key authentication is supported or not """ return False # pragma: no cover def validate_public_key(self, username, key): """Return whether key is an authorized client key for this user Basic key-based client authentication can be supported by passing authorized keys in the ``authorized_client_keys`` argument of :func:`create_server`, or by calling :meth:`set_authorized_keys ` on the server connection from the :meth:`begin_auth` method. However, for more flexibility in matching on the allowed set of keys, this method can be implemented by the application to do the matching itself. It should return ``True`` if the specified key is a valid client key for the user being authenticated. This method may be called multiple times with different keys provided by the client. Applications should precompute as much as possible in the :meth:`begin_auth` method so that this function can quickly return whether the key provided is in the list. If blocking operations need to be performed to determine the validity of the key, this method may be defined as a coroutine. By default, this method returns ``False`` for all client keys. .. note:: This function only needs to report whether the public key provided is a valid client key for this user. If it is, AsyncSSH will verify that the client possesses the corresponding private key before allowing the authentication to succeed. :param str username: The user being authenticated :param key: The public key sent by the client :type key: :class:`SSHKey` *public key* :returns: A bool indicating if the specified key is a valid client key for the user being authenticated """ return False # pragma: no cover def validate_ca_key(self, username, key): """Return whether key is an authorized CA key for this user Basic key-based client authentication can be supported by passing authorized keys in the ``authorized_client_keys`` argument of :func:`create_server`, or by calling :meth:`set_authorized_keys ` on the server connection from the :meth:`begin_auth` method. However, for more flexibility in matching on the allowed set of keys, this method can be implemented by the application to do the matching itself. It should return ``True`` if the specified key is a valid certificate authority key for the user being authenticated. This method may be called multiple times with different keys provided by the client. Applications should precompute as much as possible in the :meth:`begin_auth` method so that this function can quickly return whether the key provided is in the list. 
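As a purely illustrative alternative to overriding this method, a server can point AsyncSSH at a per-user authorized keys file (which may contain ``cert-authority`` entries) from :meth:`begin_auth` and let the matching described above happen automatically; the path pattern below is made up::

        def connection_made(self, conn):
            self._conn = conn

        def begin_auth(self, username):
            # Illustrative only; a real server should validate username
            # before using it to build a file name
            self._conn.set_authorized_keys('keys/%s' % username)
            return True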
If blocking operations need to be performed to determine the validity of the key, this method may be defined as a coroutine. By default, this method returns ``False`` for all CA keys. .. note:: This function only needs to report whether the public key provided is a valid CA key for this user. If it is, AsyncSSH will verify that the certificate is valid, that the user is one of the valid principals for the certificate, and that the client possesses the private key corresponding to the public key in the certificate before allowing the authentication to succeed. :param str username: The user being authenticated :param key: The public key which signed the certificate sent by the client :type key: :class:`SSHKey` *public key* :returns: A bool indicating if the specified key is a valid CA key for the user being authenticated """ return False # pragma: no cover def password_auth_supported(self): """Return whether or not password authentication is supported This method should return ``True`` if password authentication is supported. Applications wishing to support it must have this method return ``True`` and implement :meth:`validate_password` to return whether or not the password provided by the client is valid for the user being authenticated. By default, this method returns ``False`` indicating that password authentication is not supported. :returns: A bool indicating if password authentication is supported or not """ return False # pragma: no cover def validate_password(self, username, password): """Return whether password is valid for this user This method should return ``True`` if the specified password is a valid password for the user being authenticated. It must be overridden by applications wishing to support password authentication. If the password provided is valid but expired, this method may raise :exc:`PasswordChangeRequired` to request that the client provide a new password before authentication is allowed to complete. In this case, the application must override :meth:`change_password` to handle the password change request. This method may be called multiple times with different passwords provided by the client. Applications may wish to limit the number of attempts which are allowed. This can be done by having :meth:`password_auth_supported` begin returning ``False`` after the maximum number of attempts is exceeded. If blocking operations need to be performed to determine the validity of the password, this method may be defined as a coroutine. By default, this method returns ``False`` for all passwords. :param str username: The user being authenticated :param str password: The password sent by the client :returns: A bool indicating if the specified password is valid for the user being authenticated :raises: :exc:`PasswordChangeRequired` if the password provided is expired and needs to be changed """ return False # pragma: no cover def change_password(self, username, old_password, new_password): """Handle a request to change a user's password This method is called when a user makes a request to change their password. It should first validate that the old password provided is correct and then attempt to change the user's password to the new value. If the old password provided is valid and the change to the new password is successful, this method should return ``True``. If the old password is not valid or password changes are not supported, it should return ``False``. 
It may also raise :exc:`PasswordChangeRequired` to request that the client try again if the new password is not acceptable for some reason. If blocking operations need to be performed to determine the validity of the old password or to change to the new password, this method may be defined as a coroutine. By default, this method returns ``False``, rejecting all password changes. :param str username: The user whose password should be changed :param str old_password: The user's current password :param str new_password: The new password being requested :returns: A bool indicating if the password change is successful or not :raises: :exc:`PasswordChangeRequired` if the new password is not acceptable and the client should be asked to provide another """ return False # pragma: no cover def kbdint_auth_supported(self): """Return whether or not keyboard-interactive authentication is supported This method should return ``True`` if keyboard-interactive authentication is supported. Applications wishing to support it must have this method return ``True`` and implement :meth:`get_kbdint_challenge` and :meth:`validate_kbdint_response` to generate the apporiate challenges and validate the responses for the user being authenticated. By default, this method returns ``NotImplemented`` tying this authentication to password authentication. If the application implements password authentication and this method is not overridden, keyboard-interactive authentication will be supported by prompting for a password and passing that to the password authentication callbacks. :returns: A bool indicating if keyboard-interactive authentication is supported or not """ return NotImplemented # pragma: no cover def get_kbdint_challenge(self, username, lang, submethods): """Return a keyboard-interactive auth challenge This method should return ``True`` if authentication should succeed without any challenge, ``False`` if authentication should fail without any challenge, or an auth challenge consisting of a challenge name, instructions, a language tag, and a list of tuples containing prompt strings and booleans indicating whether input should be echoed when a value is entered for that prompt. If blocking operations need to be performed to determine the challenge to issue, this method may be defined as a coroutine. :param str username: The user being authenticated :param str lang: The language requested by the client for the challenge :param str submethods: A comma-separated list of the types of challenges the client can support, or the empty string if the server should choose :returns: An authentication challenge as described above """ return False # pragma: no cover def validate_kbdint_response(self, username, responses): """Return whether the keyboard-interactive response is valid for this user This method should validate the keyboard-interactive responses provided and return ``True`` if authentication should succeed with no further challenge, ``False`` if authentication should fail, or an additional auth challenge in the same format returned by :meth:`get_kbdint_challenge`. Any series of challenges can be returned this way. To print a message in the middle of a sequence of challenges without prompting for additional data, a challenge can be returned with an empty list of prompts. After the client acknowledges this message, this function will be called again with an empty list of responses to continue the authentication. 
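As a minimal sketch (the prompt text and expected response are made up), a server might issue a single extra prompt and check the reply here::

        def kbdint_auth_supported(self):
            return True

        def get_kbdint_challenge(self, username, lang, submethods):
            return ('One-time code', '', 'en-US', [('Code: ', False)])

        def validate_kbdint_response(self, username, responses):
            # Illustrative only; a real server would look up the
            # expected code for this user
            return len(responses) == 1 and responses[0] == '12345'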
If blocking operations need to be performed to determine the validity of the response or the next challenge to issue, this method may be defined as a coroutine. :param str username: The user being authenticated :param responses: A list of responses to the last challenge :type responses: list of str :returns: ``True``, ``False``, or the next challenge """ return False # pragma: no cover def session_requested(self): """Handle an incoming session request This method is called when a session open request is received from the client, indicating it wishes to open a channel to be used for running a shell, executing a command, or connecting to a subsystem. If the application wishes to accept the session, it must override this method to return either an :class:`SSHServerSession` object to use to process the data received on the channel or a tuple consisting of an :class:`SSHServerChannel` object created with :meth:`create_server_channel ` and an :class:`SSHServerSession`, if the application wishes to pass non-default arguments when creating the channel. If blocking operations need to be performed before the session can be created, a coroutine which returns an :class:`SSHServerSession` object can be returned instead of the session iself. This can be either returned directly or as a part of a tuple with an :class:`SSHServerChannel` object. To reject this request, this method should return ``False`` to send back a "Session refused" response or raise a :exc:`ChannelOpenError` exception with the reason for the failure. The details of what type of session the client wants to start will be delivered to methods on the :class:`SSHServerSession` object which is returned, along with other information such as environment variables, terminal type, size, and modes. By default, all session requests are rejected. :returns: One of the following: * An :class:`SSHServerSession` object or a coroutine which returns an :class:`SSHServerSession` * A tuple consisting of an :class:`SSHServerChannel` and the above * A callable or coroutine handler function which takes AsyncSSH stream objects for stdin, stdout, and stderr as arguments * A tuple consisting of an :class:`SSHServerChannel` and the above * ``False`` to refuse the request :raises: :exc:`ChannelOpenError` if the session shouldn't be accepted """ return False # pragma: no cover def connection_requested(self, dest_host, dest_port, orig_host, orig_port): """Handle a direct TCP/IP connection request This method is called when a direct TCP/IP connection request is received by the server. Applications wishing to accept such connections must override this method. To allow standard port forwarding of data on the connection to the requested destination host and port, this method should return ``True``. To reject this request, this method should return ``False`` to send back a "Connection refused" response or raise an :exc:`ChannelOpenError` exception with the reason for the failure. If the application wishes to process the data on the connection itself, this method should return either an :class:`SSHTCPSession` object which can be used to process the data received on the channel or a tuple consisting of of an :class:`SSHTCPChannel` object created with :meth:`create_tcp_channel() ` and an :class:`SSHTCPSession`, if the application wishes to pass non-default arguments when creating the channel. If blocking operations need to be performed before the session can be created, a coroutine which returns an :class:`SSHTCPSession` object can be returned instead of the session iself. 
This can be either returned directly or as a part of a tuple with an :class:`SSHTCPChannel` object. By default, all connection requests are rejected. :param str dest_host: The address the client wishes to connect to :param int dest_port: The port the client wishes to connect to :param str orig_host: The address the connection was originated from :param int orig_port: The port the connection was originated from :returns: One of the following: * An :class:`SSHTCPSession` object or a coroutine which returns an :class:`SSHTCPSession` * A tuple consisting of an :class:`SSHTCPChannel` and the above * A callable or coroutine handler function which takes AsyncSSH stream objects for reading from and writing to the connection * A tuple consisting of an :class:`SSHTCPChannel` and the above * ``True`` to request standard port forwarding * ``False`` to refuse the connection :raises: :exc:`ChannelOpenError` if the connection shouldn't be accepted """ return False # pragma: no cover def server_requested(self, listen_host, listen_port): """Handle a request to listen on a TCP/IP address and port This method is called when a client makes a request to listen on an address and port for incoming TCP connections. The port to listen on may be ``0`` to request a dynamically allocated port. Applications wishing to allow TCP/IP connection forwarding must override this method. To set up standard port forwarding of connections received on this address and port, this method should return ``True``. If the application wishes to manage listening for incoming connections itself, this method should return an :class:`SSHListener` object that listens for new connections and calls :meth:`create_connection ` on each of them to forward them back to the client or return ``None`` if the listener can't be set up. If blocking operations need to be performed to set up the listener, a coroutine which returns an :class:`SSHListener` can be returned instead of the listener itself. To reject this request, this method should return ``False``. By default, this method rejects all server requests. :param str listen_host: The address the server should listen on :param int listen_port: The port the server should listen on, or the value ``0`` to request that the server dynamically allocate a port :returns: One of the following: * An :class:`SSHListener` object or a coroutine which returns an :class:`SSHListener` or ``False`` if the listener can't be opened * ``True`` to set up standard port forwarding * ``False`` to reject the request """ return False # pragma: no cover def unix_connection_requested(self, dest_path): """Handle a direct UNIX domain socket connection request This method is called when a direct UNIX domain socket connection request is received by the server. Applications wishing to accept such connections must override this method. To allow standard path forwarding of data on the connection to the requested destination path, this method should return ``True``. To reject this request, this method should return ``False`` to send back a "Connection refused" response or raise an :exc:`ChannelOpenError` exception with the reason for the failure. 
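For example (the socket path is purely illustrative), a server which only permits forwarding to a single local socket could implement this as::

        def unix_connection_requested(self, dest_path):
            return dest_path == '/run/example/app.sock'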
If the application wishes to process the data on the connection itself, this method should return either an :class:`SSHUNIXSession` object which can be used to process the data received on the channel or a tuple consisting of of an :class:`SSHUNIXChannel` object created with :meth:`create_unix_channel() ` and an :class:`SSHUNIXSession`, if the application wishes to pass non-default arguments when creating the channel. If blocking operations need to be performed before the session can be created, a coroutine which returns an :class:`SSHUNIXSession` object can be returned instead of the session iself. This can be either returned directly or as a part of a tuple with an :class:`SSHUNIXChannel` object. By default, all connection requests are rejected. :param str dest_path: The path the client wishes to connect to :returns: One of the following: * An :class:`SSHUNIXSession` object or a coroutine which returns an :class:`SSHUNIXSession` * A tuple consisting of an :class:`SSHUNIXChannel` and the above * A callable or coroutine handler function which takes AsyncSSH stream objects for reading from and writing to the connection * A tuple consisting of an :class:`SSHUNIXChannel` and the above * ``True`` to request standard path forwarding * ``False`` to refuse the connection :raises: :exc:`ChannelOpenError` if the connection shouldn't be accepted """ return False # pragma: no cover def unix_server_requested(self, listen_path): """Handle a request to listen on a UNIX domain socket This method is called when a client makes a request to listen on a path for incoming UNIX domain socket connections. Applications wishing to allow UNIX domain socket forwarding must override this method. To set up standard path forwarding of connections received on this path, this method should return ``True``. If the application wishes to manage listening for incoming connections itself, this method should return an :class:`SSHListener` object that listens for new connections and calls :meth:`create_unix_connection ` on each of them to forward them back to the client or return ``None`` if the listener can't be set up. If blocking operations need to be performed to set up the listener, a coroutine which returns an :class:`SSHListener` can be returned instead of the listener itself. To reject this request, this method should return ``False``. By default, this method rejects all server requests. :param str listen_path: The path the server should listen on :returns: One of the following: * An :class:`SSHListener` object or a coroutine which returns an :class:`SSHListener` or ``False`` if the listener can't be opened * ``True`` to set up standard path forwarding * ``False`` to reject the request """ return False # pragma: no cover asyncssh-1.11.1/asyncssh/session.py000066400000000000000000000436271320320510200173160ustar00rootroot00000000000000# Copyright (c) 2013-2015 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH session handlers""" class SSHSession: """SSH session handler""" def connection_made(self, chan): """Called when a channel is opened successfully This method is called when a channel is opened successfully. The channel parameter should be stored if needed for later use. 
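A typical implementation (purely illustrative) simply saves the channel for use in later callbacks::

        def connection_made(self, chan):
            self._chan = chan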
:param chan: The channel which was successfully opened. :type chan: :class:`SSHClientChannel` """ pass # pragma: no cover def connection_lost(self, exc): """Called when a channel is closed This method is called when a channel is closed. If the channel is shut down cleanly, *exc* will be ``None``. Otherwise, it will be an exception explaining the reason for the channel close. :param exc: The exception which caused the channel to close, or ``None`` if the channel closed cleanly. :type exc: :class:`Exception` """ pass # pragma: no cover def session_started(self): """Called when the session is started This method is called when a session has started up. For client and server sessions, this will be called once a shell, exec, or subsystem request has been successfully completed. For TCP and UNIX domain socket sessions, it will be called immediately after the connection is opened. """ pass # pragma: no cover def data_received(self, data, datatype): """Called when data is received on the channel This method is called when data is received on the channel. If an encoding was specified when the channel was created, the data will be delivered as a string after decoding with the requested encoding. Otherwise, the data will be delivered as bytes. :param data: The data received on the channel :param datatype: The extended data type of the data, from :ref:`extended data types ` :type data: str or bytes """ pass # pragma: no cover def eof_received(self): """Called when EOF is received on the channel This method is called when an end-of-file indication is received on the channel, after which no more data will be received. If this method returns ``True``, the channel remains half open and data may still be sent. Otherwise, the channel is automatically closed after this method returns. This is the default behavior for classes derived directly from :class:`SSHSession`, but not when using the higher-level streams API. Because input is buffered in that case, streaming sessions enable half-open channels to allow applications to respond to input read after an end-of-file indication is received. """ # pylint: disable=no-self-use return False # pragma: no cover def pause_writing(self): """Called when the write buffer becomes full This method is called when the channel's write buffer becomes full and no more data can be sent until the remote system adjusts its window. While data can still be buffered locally, applications may wish to stop producing new data until the write buffer has drained. """ pass # pragma: no cover def resume_writing(self): """Called when the write buffer has sufficiently drained This method is called when the channel's send window reopens and enough data has drained from the write buffer to allow the application to produce more data. """ pass # pragma: no cover class SSHClientSession(SSHSession): """SSH client session handler Applications should subclass this when implementing an SSH client session handler. The functions listed below should be implemented to define application-specific behavior. In particular, the standard ``asyncio`` protocol methods such as :meth:`connection_made`, :meth:`connection_lost`, :meth:`data_received`, :meth:`eof_received`, :meth:`pause_writing`, and :meth:`resume_writing` are all supported. 
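For example, a minimal session which just prints received output might look like this (illustrative only)::

        class MySSHClientSession(asyncssh.SSHClientSession):
            def data_received(self, data, datatype):
                print(data, end='')

            def connection_lost(self, exc):
                if exc:
                    print('SSH session error: ' + str(exc))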
In addition, :meth:`session_started` is called as soon as the SSH session is fully started, :meth:`xon_xoff_requested` can be used to determine if the server wants the client to support XON/XOFF flow control, and :meth:`exit_status_received` and :meth:`exit_signal_received` can be used to receive session exit information. """ def xon_xoff_requested(self, client_can_do): """XON/XOFF flow control has been enabled or disabled This method is called to notify the client whether or not to enable XON/XOFF flow control. If client_can_do is ``True`` and output is being sent to an interactive terminal the application should allow input of Control-S and Control-Q to pause and resume output, respectively. If client_can_do is ``False``, Control-S and Control-Q should be treated as normal input and passed through to the server. Non-interactive applications can ignore this request. By default, this message is ignored. :param bool client_can_do: Whether or not to enable XON/XOFF flow control """ pass # pragma: no cover def exit_status_received(self, status): """A remote exit status has been received for this session This method is called when the shell, command, or subsystem running on the server terminates and returns an exit status. A zero exit status generally means that the operation was successful. This call will generally be followed by a call to :meth:`connection_lost`. By default, the exit status is ignored. :param int status: The exit status returned by the remote process """ pass # pragma: no cover def exit_signal_received(self, signal, core_dumped, msg, lang): """A remote exit signal has been received for this session This method is called when the shell, command, or subsystem running on the server terminates abnormally with a signal. A more detailed error may also be provided, along with an indication of whether the remote process dumped core. This call will generally be followed by a call to :meth:`connection_lost`. By default, exit signals are ignored. :param str signal: The signal which caused the remote process to exit :param bool core_dumped: Whether or not the remote process dumped core :param msg: Details about what error occurred :param lang: The language the error message is in """ pass # pragma: no cover class SSHServerSession(SSHSession): """SSH server session handler Applications should subclass this when implementing an SSH server session handler. The functions listed below should be implemented to define application-specific behavior. In particular, the standard ``asyncio`` protocol methods such as :meth:`connection_made`, :meth:`connection_lost`, :meth:`data_received`, :meth:`eof_received`, :meth:`pause_writing`, and :meth:`resume_writing` are all supported. In addition, :meth:`pty_requested` is called when the client requests a pseudo-terminal, one of :meth:`shell_requested`, :meth:`exec_requested`, or :meth:`subsystem_requested` is called depending on what type of session the client wants to start, :meth:`session_started` is called once the SSH session is fully started, :meth:`terminal_size_changed` is called when the client's terminal size changes, :meth:`signal_received` is called when the client sends a signal, and :meth:`break_received` is called when the client sends a break. """ def pty_requested(self, term_type, term_size, term_modes): """A psuedo-terminal has been requested This method is called when the client sends a request to allocate a pseudo-terminal with the requested terminal type, size, and POSIX terminal modes. 
This method should return ``True`` if the request for the pseudo-terminal is accepted. Otherwise, it should return ``False`` to reject the request. By default, requests to allocate a pseudo-terminal are accepted but nothing is done with the associated terminal information. Applications wishing to use this information should implement this method and have it return ``True``, or call :meth:`get_terminal_type() `, :meth:`get_terminal_size() `, or :meth:`get_terminal_mode() ` on the :class:`SSHServerChannel` to get the information they need after a shell, command, or subsystem is started. :param str term_type: Terminal type to set for this session :param tuple term_size: Terminal size to set for this session provided as a tuple of four integers: the width and height of the terminal in characters followed by the width and height of the terminal in pixels :param dictionary term_modes: POSIX terminal modes to set for this session, where keys are taken from :ref:`POSIX terminal modes ` with values defined in section 8 of :rfc:`RFC 4254 <4254#section-8>`. :returns: A bool indicating if the request for a pseudo-terminal was allowed or not """ # pylint: disable=no-self-use,unused-argument return True # pragma: no cover
def terminal_size_changed(self, width, height, pixwidth, pixheight): """The terminal size has changed This method is called when a client requests a pseudo-terminal and again whenever the size of the client's terminal window changes. By default, this information is ignored, but applications wishing to use the terminal size can implement this method to get notified whenever it changes. :param int width: The width of the terminal in characters :param int height: The height of the terminal in characters :param int pixwidth: (optional) The width of the terminal in pixels :param int pixheight: (optional) The height of the terminal in pixels """ # pylint: disable=no-self-use,unused-argument pass # pragma: no cover
def shell_requested(self): """The client has requested a shell This method should be implemented by the application to perform whatever processing is required when a client makes a request to open an interactive shell. It should return ``True`` to accept the request, or ``False`` to reject it. If the application returns ``True``, the :meth:`session_started` method will be called once the channel is fully open. No output should be sent until this method is called. By default this method returns ``False`` to reject all requests. :returns: A bool indicating if the shell request was allowed or not """ # pylint: disable=no-self-use,unused-argument return False # pragma: no cover
def exec_requested(self, command): """The client has requested to execute a command This method should be implemented by the application to perform whatever processing is required when a client makes a request to execute a command. It should return ``True`` to accept the request, or ``False`` to reject it. If the application returns ``True``, the :meth:`session_started` method will be called once the channel is fully open. No output should be sent until this method is called. By default this method returns ``False`` to reject all requests.
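For example (the allowed command is made up), a server restricted to running a single command might implement this as::

        def exec_requested(self, command):
            return command == 'uptime'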
:param str command: The command the client has requested to execute :returns: A bool indicating if the exec request was allowed or not """ # pylint: disable=no-self-use,unused-argument return False # pragma: no cover def subsystem_requested(self, subsystem): """The client has requested to start a subsystem This method should be implemented by the application to perform whatever processing is required when a client makes a request to start a subsystem. It should return ``True`` to accept the request, or ``False`` to reject it. If the application returns ``True``, the :meth:`session_started` method will be called once the channel is fully open. No output should be sent until this method is called. By default this method returns ``False`` to reject all requests. :param str subsystem: The subsystem to start :returns: A bool indicating if the request to open the subsystem was allowed or not """ # pylint: disable=no-self-use,unused-argument return False # pragma: no cover def break_received(self, msec): """The client has sent a break This method is called when the client requests that the server perform a break operation on the terminal. If the break is performed, this method should return ``True``. Otherwise, it should return ``False``. By default, this method returns ``False`` indicating that no break was performed. :param int msec: The duration of the break in milliseconds :returns: A bool to indicate if the break operation was performed or not """ # pylint: disable=no-self-use,unused-argument return False # pragma: no cover def signal_received(self, signal): """The client has sent a signal This method is called when the client delivers a signal on the channel. By default, signals from the client are ignored. """ # pylint: disable=no-self-use,unused-argument pass # pragma: no cover class SSHTCPSession(SSHSession): """SSH TCP session handler Applications should subclass this when implementing a handler for SSH direct or forwarded TCP connections. SSH client applications wishing to open a direct connection should call :meth:`create_connection() ` on their :class:`SSHClientConnection`, passing in a factory which returns instances of this class. Server applications wishing to allow direct connections should implement the coroutine :meth:`connection_requested() ` on their :class:`SSHServer` object and have it return instances of this class. Server applications wishing to allow connection forwarding back to the client should implement the coroutine :meth:`server_requested() ` on their :class:`SSHServer` object and call :meth:`create_connection() ` on their :class:`SSHServerConnection` for each new connection, passing it a factory which returns instances of this class. When a connection is successfully opened, :meth:`session_started` will be called, after which the application can begin sending data. Received data will be passed to the :meth:`data_received` method. """ class SSHUNIXSession(SSHSession): """SSH UNIX domain socket session handler Applications should subclass this when implementing a handler for SSH direct or forwarded UNIX domain socket connections. SSH client applications wishing to open a direct connection should call :meth:`create_unix_connection() ` on their :class:`SSHClientConnection`, passing in a factory which returns instances of this class. Server applications wishing to allow direct connections should implement the coroutine :meth:`unix_connection_requested() ` on their :class:`SSHServer` object and have it return instances of this class. 
Server applications wishing to allow connection forwarding back to the client should implement the coroutine :meth:`unix_server_requested() ` on their :class:`SSHServer` object and call :meth:`create_unix_connection() ` on their :class:`SSHServerConnection` for each new connection, passing it a factory which returns instances of this class. When a connection is successfully opened, :meth:`session_started` will be called, after which the application can begin sending data. Received data will be passed to the :meth:`data_received` method. """ asyncssh-1.11.1/asyncssh/sftp.py000066400000000000000000004173061320320510200166060ustar00rootroot00000000000000# Copyright (c) 2015-2017 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation # Jonathan Slenders - proposed changes to allow SFTP server callbacks # to be coroutines """SFTP handlers""" import asyncio from collections import OrderedDict import errno from fnmatch import fnmatch import os from os import SEEK_SET, SEEK_CUR, SEEK_END import posixpath import stat import sys import time from .constants import DEFAULT_LANG from .constants import FXP_INIT, FXP_VERSION, FXP_OPEN, FXP_CLOSE, FXP_READ from .constants import FXP_WRITE, FXP_LSTAT, FXP_FSTAT, FXP_SETSTAT from .constants import FXP_FSETSTAT, FXP_OPENDIR, FXP_READDIR, FXP_REMOVE from .constants import FXP_MKDIR, FXP_RMDIR, FXP_REALPATH, FXP_STAT, FXP_RENAME from .constants import FXP_READLINK, FXP_SYMLINK, FXP_STATUS, FXP_HANDLE from .constants import FXP_DATA, FXP_NAME, FXP_ATTRS, FXP_EXTENDED from .constants import FXP_EXTENDED_REPLY from .constants import FXF_READ, FXF_WRITE, FXF_APPEND from .constants import FXF_CREAT, FXF_TRUNC, FXF_EXCL from .constants import FILEXFER_ATTR_SIZE, FILEXFER_ATTR_UIDGID from .constants import FILEXFER_ATTR_PERMISSIONS, FILEXFER_ATTR_ACMODTIME from .constants import FILEXFER_ATTR_EXTENDED, FILEXFER_ATTR_UNDEFINED from .constants import FX_OK, FX_EOF, FX_NO_SUCH_FILE, FX_PERMISSION_DENIED from .constants import FX_FAILURE, FX_BAD_MESSAGE, FX_NO_CONNECTION from .constants import FX_CONNECTION_LOST, FX_OP_UNSUPPORTED from .misc import async_context_manager, Error, Record from .packet import Byte, String, UInt32, UInt64, PacketDecodeError, SSHPacket SFTP_BLOCK_SIZE = 16384 _SFTP_VERSION = 3 _MAX_SFTP_REQUESTS = 128 _MAX_READDIR_NAMES = 128 _open_modes = { 'r': FXF_READ, 'w': FXF_WRITE | FXF_CREAT | FXF_TRUNC, 'a': FXF_WRITE | FXF_CREAT | FXF_APPEND, 'x': FXF_WRITE | FXF_CREAT | FXF_EXCL, 'r+': FXF_READ | FXF_WRITE, 'w+': FXF_READ | FXF_WRITE | FXF_CREAT | FXF_TRUNC, 'a+': FXF_READ | FXF_WRITE | FXF_CREAT | FXF_APPEND, 'x+': FXF_READ | FXF_WRITE | FXF_CREAT | FXF_EXCL } def _mode_to_pflags(mode): """Convert open mode to SFTP open flags""" if 'b' in mode: mode = mode.replace('b', '') binary = True else: binary = False pflags = _open_modes.get(mode) if not pflags: raise ValueError('Invalid mode: %r' % mode) return pflags, binary def _from_local_path(path): """Convert SFTP path to local path""" path = os.fsencode(path) if sys.platform == 'win32': # pragma: no cover path = path.replace(b'\\', b'/') if path[:1] != b'/' and path[1:2] == b':': path = b'/' + path return path def _to_local_path(path): """Convert local path to SFTP path""" if sys.platform == 'win32': # pragma: no 
cover path = os.fsdecode(path) if path[:1] == '/' and path[2:3] == ':': path = path[1:] path = path.replace('/', '\\') return path def _setstat(path, attrs): """Utility function to set file attributes""" if attrs.size is not None: os.truncate(path, attrs.size) if attrs.uid is not None and attrs.gid is not None: try: os.chown(path, attrs.uid, attrs.gid) except AttributeError: # pragma: no cover raise NotImplementedError if attrs.permissions is not None: os.chmod(path, stat.S_IMODE(attrs.permissions)) if attrs.atime is not None and attrs.mtime is not None: os.utime(path, times=(attrs.atime, attrs.mtime)) @asyncio.coroutine def _glob(fs, basedir, patlist, result): """Recursively match a glob pattern""" pattern, patlist = patlist[0], patlist[1:] names = yield from fs.listdir(basedir or b'.') for name in names: if pattern != name and name in (b'.', b'..'): continue if name[:1] == b'.' and not pattern[:1] == b'.': continue if fnmatch(name, pattern): if basedir: newbase = posixpath.join(basedir, name) else: newbase = name if not patlist: result.append(newbase) else: attrs = yield from fs.stat(newbase) if stat.S_ISDIR(attrs.permissions): yield from _glob(fs, newbase, patlist, result) @asyncio.coroutine def match_glob(fs, pattern, error_handler=None): """Match a glob pattern""" names = [] try: if any(c in pattern for c in b'*?'): patlist = pattern.split(b'/') if not patlist[0]: basedir = b'/' patlist = patlist[1:] else: basedir = None yield from _glob(fs, basedir, patlist, names) if not names: raise SFTPError(FX_NO_SUCH_FILE, 'No matches found') else: yield from fs.stat(pattern) names.append(pattern) except (OSError, SFTPError) as exc: # pylint: disable=attribute-defined-outside-init exc.srcpath = pattern if error_handler: error_handler(exc) else: raise exc return names class LocalFile: """A coroutine wrapper around local file I/O""" def __init__(self, f): self._file = f @classmethod def encode(cls, path): """Encode path name using filesystem native encoding This method has no effect if the path is already bytes. """ return os.fsencode(path) @classmethod def decode(cls, path): """Decode path name using filesystem native encoding This method has no effect if the path is already a string. """ return os.fsdecode(path) @classmethod def compose_path(cls, path, parent=None): """Compose a path If parent is not specified, just encode the path. 
""" return posixpath.join(parent, path) if parent else path @classmethod @asyncio.coroutine def open(cls, path, *args): """Open a local file""" return cls(open(_to_local_path(path), *args)) @classmethod @asyncio.coroutine def stat(cls, path): """Get attributes of a local file or directory, following symlinks""" return SFTPAttrs.from_local(os.stat(_to_local_path(path))) @classmethod @asyncio.coroutine def lstat(cls, path): """Get attributes of a local file, directory, or symlink""" return SFTPAttrs.from_local(os.lstat(_to_local_path(path))) @classmethod @asyncio.coroutine def setstat(cls, path, attrs): """Set attributes of a local file or directory""" _setstat(_to_local_path(path), attrs) @classmethod @asyncio.coroutine def exists(cls, path): """Return if the local path exists and isn't a broken symbolic link""" return os.path.exists(_to_local_path(path)) @classmethod @asyncio.coroutine def isdir(cls, path): """Return if the local path refers to a directory""" return os.path.isdir(_to_local_path(path)) @classmethod @asyncio.coroutine def listdir(cls, path): """Read the names of the files in a local directory""" files = os.listdir(_to_local_path(path)) if sys.platform == 'win32': # pragma: no cover files = [os.fsencode(f) for f in files] return files @classmethod @asyncio.coroutine def mkdir(cls, path): """Create a local directory with the specified attributes""" os.mkdir(_to_local_path(path)) @classmethod @asyncio.coroutine def readlink(cls, path): """Return the target of a local symbolic link""" return _from_local_path(os.readlink(_to_local_path(path))) @classmethod @asyncio.coroutine def symlink(cls, oldpath, newpath): """Create a local symbolic link""" os.symlink(_to_local_path(oldpath), _to_local_path(newpath)) @asyncio.coroutine def read(self, size, offset): """Read data from the local file""" self._file.seek(offset) return self._file.read(size) @asyncio.coroutine def write(self, data, offset): """Write data to the local file""" self._file.seek(offset) return self._file.write(data) @asyncio.coroutine def close(self): """Close the local file""" self._file.close() class _SFTPFileCopier: """SFTP file copier This class parforms an SFTP file copy, initiating multiple read and write requests to copy chunks of the file in parallel. 
""" def __init__(self, loop): self._loop = loop self._src = None self._dst = None self._block_size = 0 self._bytes_left = 0 self._offset = 0 self._pending = set() @asyncio.coroutine def _copy_block(self, offset, size): """Copy the next block of the file""" data = yield from self._src.read(size, offset) yield from self._dst.write(data, offset) return size def _copy_blocks(self): """Create parallel requests to copy blocks from one file to another""" while self._bytes_left and len(self._pending) < _MAX_SFTP_REQUESTS: size = min(self._bytes_left, self._block_size) task = asyncio.Task(self._copy_block(self._offset, size), loop=self._loop) self._pending.add(task) self._offset += size self._bytes_left -= size @asyncio.coroutine def copy(self, srcfs, dstfs, srcpath, dstpath, total_bytes, block_size, progress_handler): """Copy a file""" try: self._src = yield from srcfs.open(srcpath, 'rb') self._dst = yield from dstfs.open(dstpath, 'wb') self._block_size = block_size self._bytes_left = total_bytes self._copy_blocks() bytes_copied = 0 while self._pending: done, self._pending = yield from asyncio.wait( self._pending, return_when=asyncio.FIRST_COMPLETED) exceptions = [] for task in done: exc = task.exception() if exc: exceptions.append(exc) elif progress_handler: bytes_copied += task.result() progress_handler(srcpath, dstpath, bytes_copied, total_bytes) if exceptions: for task in self._pending: task.cancel() raise exceptions[0] self._copy_blocks() finally: if self._src: # pragma: no branch yield from self._src.close() if self._dst: # pragma: no branch yield from self._dst.close() class SFTPError(Error): """SFTP error This exception is raised when an error occurs while processing an SFTP request. Exception codes should be taken from :ref:`SFTP error codes `. :param int code: Disconnect reason, taken from :ref:`disconnect reason codes ` :param str reason: A human-readable reason for the disconnect :param str lang: The language the reason is in """ def __init__(self, code, reason, lang=DEFAULT_LANG): super().__init__('SFTP', code, reason, lang) class SFTPAttrs(Record): """SFTP file attributes SFTPAttrs is a simple record class with the following fields: ============ =========================================== ====== Field Description Type ============ =========================================== ====== size File size in bytes uint64 uid User id of file owner uint32 gid Group id of file owner uint32 permissions Bit mask of POSIX file permissions, uint32 atime Last access time, UNIX epoch seconds uint32 mtime Last modification time, UNIX epoch seconds uint32 ============ =========================================== ====== In addition to the above, an ``nlink`` field is provided which stores the number of links to this file, but it is not encoded in the SFTP protocol. It's included here only so that it can be used to create the default ``longname`` string in :class:`SFTPName` objects. Extended attributes can also be added via a field named ``extended`` which is a list of string name/value pairs. When setting attributes using an :class:`SFTPAttrs`, only fields which have been initialized will be changed on the selected file. 
""" # Unfortunately, pylint can't handle attributes defined with setattr # pylint: disable=attribute-defined-outside-init __slots__ = OrderedDict((('size', None), ('uid', None), ('gid', None), ('permissions', None), ('atime', None), ('mtime', None), ('nlink', None), ('extended', []))) def encode(self): """Encode SFTP attributes as bytes in an SSH packet""" flags = 0 attrs = [] if self.size is not None: flags |= FILEXFER_ATTR_SIZE attrs.append(UInt64(self.size)) if self.uid is not None and self.gid is not None: flags |= FILEXFER_ATTR_UIDGID attrs.append(UInt32(self.uid) + UInt32(self.gid)) if self.permissions is not None: flags |= FILEXFER_ATTR_PERMISSIONS attrs.append(UInt32(self.permissions)) if self.atime is not None and self.mtime is not None: flags |= FILEXFER_ATTR_ACMODTIME attrs.append(UInt32(int(self.atime)) + UInt32(int(self.mtime))) if self.extended: flags |= FILEXFER_ATTR_EXTENDED attrs.append(UInt32(len(self.extended))) attrs.extend(String(type) + String(data) for type, data in self.extended) return UInt32(flags) + b''.join(attrs) @classmethod def decode(cls, packet): """Decode bytes in an SSH packet as SFTP attributes""" flags = packet.get_uint32() attrs = cls() if flags & FILEXFER_ATTR_UNDEFINED: raise SFTPError(FX_BAD_MESSAGE, 'Unsupported attribute flags') if flags & FILEXFER_ATTR_SIZE: attrs.size = packet.get_uint64() if flags & FILEXFER_ATTR_UIDGID: attrs.uid = packet.get_uint32() attrs.gid = packet.get_uint32() if flags & FILEXFER_ATTR_PERMISSIONS: attrs.permissions = packet.get_uint32() if flags & FILEXFER_ATTR_ACMODTIME: attrs.atime = packet.get_uint32() attrs.mtime = packet.get_uint32() if flags & FILEXFER_ATTR_EXTENDED: count = packet.get_uint32() attrs.extended = [] for _ in range(count): attr = packet.get_string() data = packet.get_string() attrs.extended.append((attr, data)) return attrs @classmethod def from_local(cls, result): """Convert from local stat attributes""" return cls(result.st_size, result.st_uid, result.st_gid, result.st_mode, result.st_atime, result.st_mtime, result.st_nlink) class SFTPVFSAttrs(Record): """SFTP file system attributes SFTPVFSAttrs is a simple record class with the following fields: ============ =========================================== ====== Field Description Type ============ =========================================== ====== bsize File system block size (I/O size) uint64 frsize Fundamental block size (allocation size) uint64 blocks Total data blocks (in frsize units) uint64 bfree Free data blocks uint64 bavail Available data blocks (for non-root) uint64 files Total file inodes uint64 ffree Free file inodes uint64 favail Available file inodes (for non-root) uint64 fsid File system id uint64 flags File system flags (read-only, no-setuid) uint64 namemax Maximum filename length uint64 ============ =========================================== ====== """ # Unfortunately, pylint can't handle attributes defined with setattr # pylint: disable=attribute-defined-outside-init __slots__ = OrderedDict((('bsize', 0), ('frsize', 0), ('blocks', 0), ('bfree', 0), ('bavail', 0), ('files', 0), ('ffree', 0), ('favail', 0), ('fsid', 0), ('flags', 0), ('namemax', 0))) def encode(self): """Encode SFTP statvfs attributes as bytes in an SSH packet""" return b''.join((UInt64(self.bsize), UInt64(self.frsize), UInt64(self.blocks), UInt64(self.bfree), UInt64(self.bavail), UInt64(self.files), UInt64(self.ffree), UInt64(self.favail), UInt64(self.fsid), UInt64(self.flags), UInt64(self.namemax))) @classmethod def decode(cls, packet): """Decode bytes in an SSH 
packet as SFTP statvfs attributes""" vfsattrs = cls() vfsattrs.bsize = packet.get_uint64() vfsattrs.frsize = packet.get_uint64() vfsattrs.blocks = packet.get_uint64() vfsattrs.bfree = packet.get_uint64() vfsattrs.bavail = packet.get_uint64() vfsattrs.files = packet.get_uint64() vfsattrs.ffree = packet.get_uint64() vfsattrs.favail = packet.get_uint64() vfsattrs.fsid = packet.get_uint64() vfsattrs.flags = packet.get_uint64() vfsattrs.namemax = packet.get_uint64() return vfsattrs @classmethod def from_local(cls, result): """Convert from local statvfs attributes""" return cls(result.f_bsize, result.f_frsize, result.f_blocks, result.f_bfree, result.f_bavail, result.f_files, result.f_ffree, result.f_favail, 0, result.f_flag, result.f_namemax) class SFTPName(Record): """SFTP file name and attributes SFTPName is a simple record class with the following fields: ========= ================================== ================== Field Description Type ========= ================================== ================== filename Filename str or bytes longname Expanded form of filename & attrs str or bytes attrs File attributes :class:`SFTPAttrs` ========= ================================== ================== A list of these is returned by :meth:`readdir() ` in :class:`SFTPClient` when retrieving the contents of a directory. """ __slots__ = OrderedDict((('filename', ''), ('longname', ''), ('attrs', SFTPAttrs()))) def encode(self): """Encode an SFTP name as bytes in an SSH packet""" # pylint: disable=no-member return (String(self.filename) + String(self.longname) + self.attrs.encode()) @classmethod def decode(cls, packet): """Decode bytes in an SSH packet as an SFTP name""" filename = packet.get_string() longname = packet.get_string() attrs = SFTPAttrs.decode(packet) return cls(filename, longname, attrs) class SFTPHandler: """SFTP session handler""" # SFTP implementations with broken order for SYMLINK arguments _nonstandard_symlink_impls = ['OpenSSH', 'paramiko'] # Return types by message -- unlisted entries always return FXP_STATUS, # those below return FXP_STATUS on error _return_types = { FXP_OPEN: FXP_HANDLE, FXP_READ: FXP_DATA, FXP_LSTAT: FXP_ATTRS, FXP_FSTAT: FXP_ATTRS, FXP_OPENDIR: FXP_HANDLE, FXP_READDIR: FXP_NAME, FXP_REALPATH: FXP_NAME, FXP_STAT: FXP_ATTRS, FXP_READLINK: FXP_NAME, b'statvfs@openssh.com': FXP_EXTENDED_REPLY, b'fstatvfs@openssh.com': FXP_EXTENDED_REPLY } def __init__(self, reader, writer): self._reader = reader self._writer = writer @asyncio.coroutine def _cleanup(self, exc): """Clean up this SFTP session""" # pylint: disable=unused-argument if self._writer: # pragma: no branch self._writer.close() self._reader = None self._writer = None @asyncio.coroutine def _process_packet(self, pkttype, pktid, packet): """Abstract method for processing SFTP packets""" raise NotImplementedError def send_packet(self, *args): """Send an SFTP packet""" payload = b''.join(args) try: self._writer.write(UInt32(len(payload)) + payload) except ConnectionError as exc: raise SFTPError(FX_CONNECTION_LOST, str(exc)) from None @asyncio.coroutine def recv_packet(self): """Receive an SFTP packet""" try: pktlen = yield from self._reader.readexactly(4) pktlen = int.from_bytes(pktlen, 'big') packet = yield from self._reader.readexactly(pktlen) packet = SSHPacket(packet) except EOFError: raise SFTPError(FX_CONNECTION_LOST, 'Channel closed') from None return packet @asyncio.coroutine def recv_packets(self): """Receive and process SFTP packets""" try: while self._reader: # pragma: no branch packet = yield from 
self.recv_packet() pkttype = packet.get_byte() pktid = packet.get_uint32() yield from self._process_packet(pkttype, pktid, packet) except PacketDecodeError as exc: yield from self._cleanup(SFTPError(FX_BAD_MESSAGE, str(exc))) except (OSError, SFTPError) as exc: yield from self._cleanup(exc) class SFTPClientHandler(SFTPHandler): """An SFTP client session handler""" _extensions = [] def __init__(self, loop, reader, writer): super().__init__(reader, writer) self._loop = loop self._version = None self._next_pktid = 0 self._requests = {} self._nonstandard_symlink = False self._supports_posix_rename = False self._supports_statvfs = False self._supports_fstatvfs = False self._supports_hardlink = False self._supports_fsync = False @asyncio.coroutine def _cleanup(self, exc): """Clean up this SFTP client session""" for waiter in self._requests.values(): if waiter and not waiter.cancelled(): waiter.set_exception(exc) self._requests = {} yield from super()._cleanup(exc) @asyncio.coroutine def _process_packet(self, pkttype, pktid, packet): """Process incoming SFTP responses""" try: waiter = self._requests.pop(pktid) except KeyError: yield from self._cleanup(SFTPError(FX_BAD_MESSAGE, 'Invalid response id')) else: if waiter and not waiter.cancelled(): waiter.set_result((pkttype, packet)) def _send_request(self, pkttype, *args, waiter=None): """Send an SFTP request""" if not self._writer: raise SFTPError(FX_NO_CONNECTION, 'Connection not open') pktid = self._next_pktid self._next_pktid = (self._next_pktid + 1) & 0xffffffff self._requests[pktid] = waiter if isinstance(pkttype, bytes): hdr = Byte(FXP_EXTENDED) + UInt32(pktid) + String(pkttype) else: hdr = Byte(pkttype) + UInt32(pktid) self.send_packet(hdr, *args) @asyncio.coroutine def _make_request(self, pkttype, *args): """Make an SFTP request and wait for a response""" waiter = asyncio.Future(loop=self._loop) self._send_request(pkttype, *args, waiter=waiter) resptype, resp = yield from waiter return_type = self._return_types.get(pkttype) if resptype not in (FXP_STATUS, return_type): raise SFTPError(FX_BAD_MESSAGE, 'Unexpected response type: %s' % resptype) result = self._packet_handlers[resptype](self, resp) if result is not None or return_type is None: return result else: raise SFTPError(FX_BAD_MESSAGE, 'Unexpected FX_OK response') def _process_status(self, packet): """Process an incoming SFTP status response""" # pylint: disable=no-self-use code = packet.get_uint32() try: reason = packet.get_string().decode('utf-8') lang = packet.get_string().decode('ascii') except UnicodeDecodeError: raise SFTPError(FX_BAD_MESSAGE, 'Invalid status message') from None packet.check_end() if code == FX_OK: return None else: raise SFTPError(code, reason, lang) def _process_handle(self, packet): """Process an incoming SFTP handle response""" # pylint: disable=no-self-use handle = packet.get_string() packet.check_end() return handle def _process_data(self, packet): """Process an incoming SFTP data response""" # pylint: disable=no-self-use data = packet.get_string() packet.check_end() return data def _process_name(self, packet): """Process an incoming SFTP name response""" # pylint: disable=no-self-use count = packet.get_uint32() names = [SFTPName.decode(packet) for i in range(count)] packet.check_end() return names def _process_attrs(self, packet): """Process an incoming SFTP attributes response""" # pylint: disable=no-self-use attrs = SFTPAttrs().decode(packet) packet.check_end() return attrs def _process_extended_reply(self, packet): """Process an incoming SFTP extended 
reply response""" # pylint: disable=no-self-use # Let the caller do the decoding for extended replies return packet _packet_handlers = { FXP_STATUS: _process_status, FXP_HANDLE: _process_handle, FXP_DATA: _process_data, FXP_NAME: _process_name, FXP_ATTRS: _process_attrs, FXP_EXTENDED_REPLY: _process_extended_reply } @asyncio.coroutine def start(self): """Start an SFTP client""" extensions = (String(name) + String(data) for name, data in self._extensions) self.send_packet(Byte(FXP_INIT), UInt32(_SFTP_VERSION), *extensions) resp = yield from self.recv_packet() resptype = resp.get_byte() if resptype != FXP_VERSION: raise SFTPError(FX_BAD_MESSAGE, 'Expected version message') version = resp.get_uint32() extensions = [] while resp: name = resp.get_string() data = resp.get_string() extensions.append((name, data)) if version != _SFTP_VERSION: raise SFTPError(FX_BAD_MESSAGE, 'Unsupported version: %d' % version) self._version = version for name, data in extensions: if name == b'posix-rename@openssh.com' and data == b'1': self._supports_posix_rename = True elif name == b'statvfs@openssh.com' and data == b'2': self._supports_statvfs = True elif name == b'fstatvfs@openssh.com' and data == b'2': self._supports_fstatvfs = True elif name == b'hardlink@openssh.com' and data == b'1': self._supports_hardlink = True elif name == b'fsync@openssh.com' and data == b'1': self._supports_fsync = True if version == 3: # Check if the server has a buggy SYMLINK implementation server_version = self._reader.get_extra_info('server_version', '') if any(name in server_version for name in self._nonstandard_symlink_impls): self._nonstandard_symlink = True @asyncio.coroutine def open(self, filename, pflags, attrs): """Make an SFTP open request""" return (yield from self._make_request(FXP_OPEN, String(filename), UInt32(pflags), attrs.encode())) @asyncio.coroutine def close(self, handle): """Make an SFTP close request""" return (yield from self._make_request(FXP_CLOSE, String(handle))) def nonblocking_close(self, handle): """Send an SFTP close request without blocking on the response""" # Used by context managers, since they can't block to wait for a reply self._send_request(FXP_CLOSE, String(handle)) @asyncio.coroutine def read(self, handle, offset, length): """Make an SFTP read request""" return (yield from self._make_request(FXP_READ, String(handle), UInt64(offset), UInt32(length))) @asyncio.coroutine def write(self, handle, offset, data): """Make an SFTP write request""" return (yield from self._make_request(FXP_WRITE, String(handle), UInt64(offset), String(data))) @asyncio.coroutine def stat(self, path): """Make an SFTP stat request""" return (yield from self._make_request(FXP_STAT, String(path))) @asyncio.coroutine def lstat(self, path): """Make an SFTP lstat request""" return (yield from self._make_request(FXP_LSTAT, String(path))) @asyncio.coroutine def fstat(self, handle): """Make an SFTP fstat request""" return (yield from self._make_request(FXP_FSTAT, String(handle))) @asyncio.coroutine def setstat(self, path, attrs): """Make an SFTP setstat request""" return (yield from self._make_request(FXP_SETSTAT, String(path), attrs.encode())) @asyncio.coroutine def fsetstat(self, handle, attrs): """Make an SFTP fsetstat request""" return (yield from self._make_request(FXP_FSETSTAT, String(handle), attrs.encode())) @asyncio.coroutine def statvfs(self, path): """Make an SFTP statvfs request""" if self._supports_statvfs: packet = yield from self._make_request(b'statvfs@openssh.com', String(path)) vfsattrs = 
SFTPVFSAttrs.decode(packet) packet.check_end() return vfsattrs else: raise SFTPError(FX_OP_UNSUPPORTED, 'statvfs not supported') @asyncio.coroutine def fstatvfs(self, handle): """Make an SFTP fstatvfs request""" if self._supports_fstatvfs: packet = yield from self._make_request(b'fstatvfs@openssh.com', String(handle)) vfsattrs = SFTPVFSAttrs.decode(packet) packet.check_end() return vfsattrs else: raise SFTPError(FX_OP_UNSUPPORTED, 'fstatvfs not supported') @asyncio.coroutine def remove(self, path): """Make an SFTP remove request""" return (yield from self._make_request(FXP_REMOVE, String(path))) @asyncio.coroutine def rename(self, oldpath, newpath): """Make an SFTP rename request""" return (yield from self._make_request(FXP_RENAME, String(oldpath), String(newpath))) @asyncio.coroutine def posix_rename(self, oldpath, newpath): """Make an SFTP POSIX rename request""" if self._supports_posix_rename: return (yield from self._make_request(b'posix-rename@openssh.com', String(oldpath), String(newpath))) else: raise SFTPError(FX_OP_UNSUPPORTED, 'POSIX rename not supported') @asyncio.coroutine def opendir(self, path): """Make an SFTP opendir request""" return (yield from self._make_request(FXP_OPENDIR, String(path))) @asyncio.coroutine def readdir(self, handle): """Make an SFTP readdir request""" return (yield from self._make_request(FXP_READDIR, String(handle))) @asyncio.coroutine def mkdir(self, path, attrs): """Make an SFTP mkdir request""" return (yield from self._make_request(FXP_MKDIR, String(path), attrs.encode())) @asyncio.coroutine def rmdir(self, path): """Make an SFTP rmdir request""" return (yield from self._make_request(FXP_RMDIR, String(path))) @asyncio.coroutine def realpath(self, path): """Make an SFTP realpath request""" return (yield from self._make_request(FXP_REALPATH, String(path))) @asyncio.coroutine def readlink(self, path): """Make an SFTP readlink request""" return (yield from self._make_request(FXP_READLINK, String(path))) @asyncio.coroutine def symlink(self, oldpath, newpath): """Make an SFTP symlink request""" if self._nonstandard_symlink: args = String(oldpath) + String(newpath) else: args = String(newpath) + String(oldpath) return (yield from self._make_request(FXP_SYMLINK, args)) @asyncio.coroutine def link(self, oldpath, newpath): """Make an SFTP link request""" if self._supports_hardlink: return (yield from self._make_request(b'hardlink@openssh.com', String(oldpath), String(newpath))) else: raise SFTPError(FX_OP_UNSUPPORTED, 'link not supported') @asyncio.coroutine def fsync(self, handle): """Make an SFTP fsync request""" if self._supports_fsync: return (yield from self._make_request(b'fsync@openssh.com', String(handle))) else: raise SFTPError(FX_OP_UNSUPPORTED, 'fsync not supported') def exit(self): """Handle a request to close the SFTP session""" if self._writer: self._writer.write_eof() @asyncio.coroutine def wait_closed(self): """Wait for this SFTP session to close""" if self._writer: yield from self._writer.channel.wait_closed() class SFTPClientFile: """SFTP client remote file object This class represents an open file on a remote SFTP server. It is opened with the :meth:`open() ` method on the :class:`SFTPClient` class and provides methods to read and write data and get and set attributes on the open file. 
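       A minimal usage sketch (illustrative only, assuming ``sftp`` is an
       established :class:`SFTPClient`, that this code runs inside a
       coroutine, and that the path shown is hypothetical)::

           with (yield from sftp.open('example.txt', 'w')) as f:
               yield from f.write('Hello, world!')

           with (yield from sftp.open('example.txt')) as f:
               data = yield from f.read()
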
""" def __init__(self, handler, handle, appending, encoding, errors): self._handler = handler self._handle = handle self._appending = appending self._encoding = encoding self._errors = errors self._offset = None if appending else 0 def __enter__(self): """Allow SFTPClientFile to be used as a context manager""" return self def __exit__(self, *exc_info): """Automatically close the file when used as a context manager""" if self._handle: self._handler.nonblocking_close(self._handle) self._handle = None @asyncio.coroutine def __aenter__(self): """Allow SFTPClientFile to be used as an async context manager""" return self @asyncio.coroutine def __aexit__(self, *exc_info): """Wait for file close when used as an async context manager""" yield from self.close() @asyncio.coroutine def _end(self): """Return the offset of the end of the file""" attrs = yield from self.stat() return attrs.size @asyncio.coroutine def read(self, size=-1, offset=None): """Read data from the remote file This method reads and returns up to ``size`` bytes of data from the remote file. If size is negative, all data up to the end of the file is returned. If offset is specified, the read will be performed starting at that offset rather than the current file position. This argument should be provided if you want to issue parallel reads on the same file, since the file position is not predictable in that case. Data will be returned as a string if an encoding was set when the file was opened. Otherwise, data is returned as bytes. An empty string or bytes object is returned when at EOF. :param int size: The number of bytes to read :param int offset: (optional) The offset from the beginning of the file to begin reading :returns: data read from the file, as a str or bytes :raises: | :exc:`ValueError` if the file has been closed | :exc:`UnicodeDecodeError` if the data can't be decoded using the requested encoding | :exc:`SFTPError` if the server returns an error """ if self._handle is None: raise ValueError('I/O operation on closed file') if offset is None: offset = self._offset if offset is None: # We're appending and haven't seeked backward in the file # since the last write, so there's no data to return data = b'' elif size is None or size < 0: data = [] try: while True: result = yield from self._handler.read(self._handle, offset, SFTP_BLOCK_SIZE) data.append(result) offset += len(result) self._offset = offset except SFTPError as exc: if exc.code != FX_EOF: raise data = b''.join(data) else: data = b'' try: data = yield from self._handler.read(self._handle, offset, size) self._offset = offset + len(data) except SFTPError as exc: if exc.code != FX_EOF: raise if self._encoding: data = data.decode(self._encoding, self._errors) return data @asyncio.coroutine def write(self, data, offset=None): """Write data to the remote file This method writes the specified data at the current position in the remote file. :param data: The data to write to the file :param int offset: (optional) The offset from the beginning of the file to begin writing :type data: str or bytes If offset is specified, the write will be performed starting at that offset rather than the current file position. This argument should be provided if you want to issue parallel writes on the same file, since the file position is not predictable in that case. 
:returns: number of bytes written :raises: | :exc:`ValueError` if the file has been closed | :exc:`UnicodeEncodeError` if the data can't be encoded using the requested encoding | :exc:`SFTPError` if the server returns an error """ if self._handle is None: raise ValueError('I/O operation on closed file') if offset is None: # Offset is ignored when appending, so fill in an offset of 0 # if we don't have a current file position offset = self._offset or 0 if self._encoding: data = data.encode(self._encoding, self._errors) yield from self._handler.write(self._handle, offset, data) self._offset = None if self._appending else offset + len(data) return len(data) @asyncio.coroutine def seek(self, offset, from_what=SEEK_SET): """Seek to a new position in the remote file This method changes the position in the remote file. The ``offset`` passed in is treated as relative to the beginning of the file if ``from_what`` is set to ``SEEK_SET`` (the default), relative to the current file position if it is set to ``SEEK_CUR``, or relative to the end of the file if it is set to ``SEEK_END``. :param int offset: The amount to seek :param int from_what: (optional) The reference point to use (SEEK_SET, SEEK_CUR, or SEEK_END) :returns: The new byte offset from the beginning of the file """ if self._handle is None: raise ValueError('I/O operation on closed file') if from_what == SEEK_SET: self._offset = offset elif from_what == SEEK_CUR: self._offset += offset elif from_what == SEEK_END: self._offset = (yield from self._end()) + offset else: raise ValueError('Invalid reference point') return self._offset @asyncio.coroutine def tell(self): """Return the current position in the remote file This method returns the current position in the remote file. :returns: The current byte offset from the beginning of the file """ if self._handle is None: raise ValueError('I/O operation on closed file') if self._offset is None: self._offset = yield from self._end() return self._offset @asyncio.coroutine def stat(self): """Return file attributes of the remote file This method queries file attributes of the currently open file. :returns: An :class:`SFTPAttrs` containing the file attributes :raises: :exc:`SFTPError` if the server returns an error """ if self._handle is None: raise ValueError('I/O operation on closed file') return (yield from self._handler.fstat(self._handle)) @asyncio.coroutine def setstat(self, attrs): """Set attributes of the remote file This method sets file attributes of the currently open file. :param attrs: File attributes to set on the file :type attrs: :class:`SFTPAttrs` :raises: :exc:`SFTPError` if the server returns an error """ if self._handle is None: raise ValueError('I/O operation on closed file') yield from self._handler.fsetstat(self._handle, attrs) @asyncio.coroutine def statvfs(self): """Return file system attributes of the remote file This method queries attributes of the file system containing the currently open file. :returns: An :class:`SFTPVFSAttrs` containing the file system attributes :raises: :exc:`SFTPError` if the server doesn't support this extension or returns an error """ if self._handle is None: raise ValueError('I/O operation on closed file') return (yield from self._handler.fstatvfs(self._handle)) @asyncio.coroutine def truncate(self, size=None): """Truncate the remote file to the specified size This method changes the remote file's size to the specified value. If a size is not provided, the current file position is used. 
:param int size: (optional) The desired size of the file, in bytes :raises: :exc:`SFTPError` if the server returns an error """ if size is None: size = self._offset yield from self.setstat(SFTPAttrs(size=size)) @asyncio.coroutine def chown(self, uid, gid): """Change the owner user and group id of the remote file This method changes the user and group id of the currently open file. :param int uid: The new user id to assign to the file :param int gid: The new group id to assign to the file :raises: :exc:`SFTPError` if the server returns an error """ yield from self.setstat(SFTPAttrs(uid=uid, gid=gid)) @asyncio.coroutine def chmod(self, mode): """Change the file permissions of the remote file This method changes the permissions of the currently open file. :param int mode: The new file permissions, expressed as an int :raises: :exc:`SFTPError` if the server returns an error """ yield from self.setstat(SFTPAttrs(permissions=mode)) @asyncio.coroutine def utime(self, times=None): """Change the access and modify times of the remote file This method changes the access and modify times of the currently open file. If ``times`` is not provided, the times will be changed to the current time. :param times: (optional) The new access and modify times, as seconds relative to the UNIX epoch :type times: tuple of two int or float values :raises: :exc:`SFTPError` if the server returns an error """ # pylint: disable=unpacking-non-sequence if times is None: atime = mtime = time.time() else: atime, mtime = times yield from self.setstat(SFTPAttrs(atime=atime, mtime=mtime)) @asyncio.coroutine def fsync(self): """Force the remote file data to be written to disk""" if self._handle is None: raise ValueError('I/O operation on closed file') yield from self._handler.fsync(self._handle) @asyncio.coroutine def close(self): """Close the remote file""" if self._handle: yield from self._handler.close(self._handle) self._handle = None class SFTPClient: """SFTP client This class represents the client side of an SFTP session. It is started by calling the :meth:`start_sftp_client() ` method on the :class:`SSHClientConnection` class. """ def __init__(self, loop, handler, path_encoding, path_errors): self._loop = loop self._handler = handler self._path_encoding = path_encoding self._path_errors = path_errors self._cwd = None def __enter__(self): """Allow SFTPClient to be used as a context manager""" return self def __exit__(self, *exc_info): """Automatically close the session when used as a context manager""" self.exit() @asyncio.coroutine def __aenter__(self): """Allow SFTPClient to be used as an async context manager""" return self @asyncio.coroutine def __aexit__(self, *exc_info): """Wait for client close when used as an async context manager""" self.__exit__() yield from self.wait_closed() def encode(self, path): """Encode path name using configured path encoding This method has no effect if the path is already bytes. """ if isinstance(path, str): if self._path_encoding: path = path.encode(self._path_encoding, self._path_errors) else: raise SFTPError(FX_BAD_MESSAGE, 'Path must be bytes when ' 'encoding is not set') return path def decode(self, path, want_string=True): """Decode path name using configured path encoding This method has no effect if want_string is set to ``False``. 
""" if want_string and self._path_encoding: try: path = path.decode(self._path_encoding, self._path_errors) except UnicodeDecodeError: raise SFTPError(FX_BAD_MESSAGE, 'Unable to decode name') from None return path def compose_path(self, path, parent=...): """Compose a path If parent is not specified, return a path relative to the current remote working directory. """ if parent is ...: parent = self._cwd path = self.encode(path) return posixpath.join(parent, path) if parent else path @asyncio.coroutine def _mode(self, path, statfunc=None): """Return the mode of a remote path, or 0 if it can't be accessed""" if statfunc is None: statfunc = self.stat try: return (yield from statfunc(path)).permissions except SFTPError as exc: if exc.code in (FX_NO_SUCH_FILE, FX_PERMISSION_DENIED): return 0 else: raise @asyncio.coroutine def _glob(self, fs, patterns, error_handler): """Begin a new glob pattern match""" # pylint: disable=no-self-use if isinstance(patterns, (str, bytes)): patterns = [patterns] result = [] for pattern in patterns: if not pattern: continue names = yield from match_glob(fs, fs.encode(pattern), error_handler) if isinstance(pattern, str): names = [fs.decode(name) for name in names] result.extend(names) return result @asyncio.coroutine def _copy(self, srcfs, dstfs, srcpath, dstpath, preserve, recurse, follow_symlinks, block_size, progress_handler, error_handler): """Copy a file, directory, or symbolic link""" if follow_symlinks: srcattrs = yield from srcfs.stat(srcpath) else: srcattrs = yield from srcfs.lstat(srcpath) try: if stat.S_ISDIR(srcattrs.permissions): if not recurse: raise SFTPError(FX_FAILURE, '%s is a directory' % srcpath.decode('utf-8', errors='replace')) if not (yield from dstfs.isdir(dstpath)): yield from dstfs.mkdir(dstpath) names = yield from srcfs.listdir(srcpath) for name in names: if name in (b'.', b'..'): continue srcfile = posixpath.join(srcpath, name) dstfile = posixpath.join(dstpath, name) yield from self._copy(srcfs, dstfs, srcfile, dstfile, preserve, recurse, follow_symlinks, block_size, progress_handler, error_handler) elif stat.S_ISLNK(srcattrs.permissions): targetpath = yield from srcfs.readlink(srcpath) yield from dstfs.symlink(targetpath, dstpath) else: yield from _SFTPFileCopier(self._loop).copy( srcfs, dstfs, srcpath, dstpath, srcattrs.size, block_size, progress_handler) if preserve: srcattrs = yield from srcfs.stat(srcpath) yield from dstfs.setstat( dstpath, SFTPAttrs(permissions=srcattrs.permissions, atime=srcattrs.atime, mtime=srcattrs.mtime)) except (OSError, SFTPError) as exc: # pylint: disable=attribute-defined-outside-init exc.srcpath = srcpath exc.dstpath = dstpath if error_handler: error_handler(exc) else: raise @asyncio.coroutine def _begin_copy(self, srcfs, dstfs, srcpaths, dstpath, preserve, recurse, follow_symlinks, block_size, progress_handler, error_handler): """Begin a new file upload, download, or copy""" dst_isdir = dstpath is None or (yield from dstfs.isdir(dstpath)) if dstpath: dstpath = dstfs.encode(dstpath) if isinstance(srcpaths, (str, bytes)): srcpaths = [srcpaths] elif not dst_isdir: raise SFTPError(FX_FAILURE, '%s must be a directory' % dstpath.decode('utf-8', errors='replace')) for srcfile in srcpaths: srcfile = srcfs.encode(srcfile) filename = posixpath.basename(srcfile) if dstpath is None: dstfile = filename elif dst_isdir: dstfile = dstfs.compose_path(filename, parent=dstpath) else: dstfile = dstpath yield from self._copy(srcfs, dstfs, srcfile, dstfile, preserve, recurse, follow_symlinks, block_size, progress_handler, 
error_handler) @asyncio.coroutine def get(self, remotepaths, localpath=None, *, preserve=False, recurse=False, follow_symlinks=False, block_size=SFTP_BLOCK_SIZE, progress_handler=None, error_handler=None): """Download remote files This method downloads one or more files or directories from the remote system. Either a single remote path or a sequence of remote paths to download can be provided. When downloading a single file or directory, the local path can be either the full path to download data into or the path to an existing directory where the data should be placed. In the latter case, the base file name from the remote path will be used as the local name. When downloading multiple files, the local path must refer to an existing directory. If no local path is provided, the file is downloaded into the current local working directory. If preserve is ``True``, the access and modification times and permissions of the original file are set on the downloaded file. If recurse is ``True`` and the remote path points at a directory, the entire subtree under that directory is downloaded. If follow_symlinks is set to ``True``, symbolic links found on the remote system will have the contents of their target downloaded rather than creating a local symbolic link. When using this option during a recursive download, one needs to watch out for links that result in loops. The block_size value controls the size of read and write operations issued to download the files. It defaults to 16 KB. If progress_handler is specified, it will be called after each block of a file is successfully downloaded. The arguments passed to this handler will be the source path, destination path, bytes downloaded so far, and total bytes in the file being downloaded. If multiple source paths are provided or recurse is set to ``True``, the progress_handler will be called consecutively on each file being downloaded. If error_handler is specified and an error occurs during the download, this handler will be called with the exception instead of it being raised. This is intended to primarily be used when multiple remote paths are provided or when recurse is set to ``True``, to allow error information to be collected without aborting the download of the remaining files. The error handler can raise an exception if it wants the download to completely stop. Otherwise, after an error, the download will continue starting with the next file. 
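           A minimal sketch (assuming ``sftp`` is an established
           :class:`SFTPClient`, that this code runs inside a coroutine,
           and that the remote paths shown exist)::

               def report(srcpath, dstpath, bytes_copied, total_bytes):
                   print(srcpath, dstpath, bytes_copied, total_bytes)

               yield from sftp.get('example.txt')
               yield from sftp.get(['a.txt', 'b.txt'], '/tmp/downloads',
                                   progress_handler=report)
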
:param remotepaths: The paths of the remote files or directories to download :param str localpath: (optional) The path of the local file or directory to download into :param bool preserve: (optional) Whether or not to preserve the original file attributes :param bool recurse: (optional) Whether or not to recursively copy directories :param bool follow_symlinks: (optional) Whether or not to follow symbolic links :param int block_size: (optional) The block size to use for file reads and writes :param callable progress_handler: (optional) The function to call to report download progress :param callable error_handler: (optional) The function to call when an error occurs :type remotepaths: str or bytes, or a sequence of these :raises: | :exc:`OSError` if a local file I/O error occurs | :exc:`SFTPError` if the server returns an error """ yield from self._begin_copy(self, LocalFile, remotepaths, localpath, preserve, recurse, follow_symlinks, block_size, progress_handler, error_handler) @asyncio.coroutine def put(self, localpaths, remotepath=None, *, preserve=False, recurse=False, follow_symlinks=False, block_size=SFTP_BLOCK_SIZE, progress_handler=None, error_handler=None): """Upload local files This method uploads one or more files or directories to the remote system. Either a single local path or a sequence of local paths to upload can be provided. When uploading a single file or directory, the remote path can be either the full path to upload data into or the path to an existing directory where the data should be placed. In the latter case, the base file name from the local path will be used as the remote name. When uploading multiple files, the remote path must refer to an existing directory. If no remote path is provided, the file is uploaded into the current remote working directory. If preserve is ``True``, the access and modification times and permissions of the original file are set on the uploaded file. If recurse is ``True`` and the local path points at a directory, the entire subtree under that directory is uploaded. If follow_symlinks is set to ``True``, symbolic links found on the local system will have the contents of their target uploaded rather than creating a remote symbolic link. When using this option during a recursive upload, one needs to watch out for links that result in loops. The block_size value controls the size of read and write operations issued to upload the files. It defaults to 16 KB. If progress_handler is specified, it will be called after each block of a file is successfully uploaded. The arguments passed to this handler will be the source path, destination path, bytes uploaded so far, and total bytes in the file being uploaded. If multiple source paths are provided or recurse is set to ``True``, the progress_handler will be called consecutively on each file being uploaded. If error_handler is specified and an error occurs during the upload, this handler will be called with the exception instead of it being raised. This is intended to primarily be used when multiple local paths are provided or when recurse is set to ``True``, to allow error information to be collected without aborting the upload of the remaining files. The error handler can raise an exception if it wants the upload to completely stop. Otherwise, after an error, the upload will continue starting with the next file. 
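           A minimal sketch (assuming ``sftp`` is an established
           :class:`SFTPClient`, that this code runs inside a coroutine,
           and that the local paths shown exist)::

               yield from sftp.put('example.txt')
               yield from sftp.put('photos', 'backup',
                                   preserve=True, recurse=True)
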
:param localpaths: The paths of the local files or directories to upload :param remotepath: (optional) The path of the remote file or directory to upload into :param bool preserve: (optional) Whether or not to preserve the original file attributes :param bool recurse: (optional) Whether or not to recursively copy directories :param bool follow_symlinks: (optional) Whether or not to follow symbolic links :param int block_size: (optional) The block size to use for file reads and writes :param callable progress_handler: (optional) The function to call to report upload progress :param callable error_handler: (optional) The function to call when an error occurs :type localpaths: str or bytes, or a sequence of these :type remotepath: str or bytes :raises: | :exc:`OSError` if a local file I/O error occurs | :exc:`SFTPError` if the server returns an error """ yield from self._begin_copy(LocalFile, self, localpaths, remotepath, preserve, recurse, follow_symlinks, block_size, progress_handler, error_handler) @asyncio.coroutine def copy(self, srcpaths, dstpath=None, *, preserve=False, recurse=False, follow_symlinks=False, block_size=SFTP_BLOCK_SIZE, progress_handler=None, error_handler=None): """Copy remote files to a new location This method copies one or more files or directories on the remote system to a new location. Either a single source path or a sequence of source paths to copy can be provided. When copying a single file or directory, the destination path can be either the full path to copy data into or the path to an existing directory where the data should be placed. In the latter case, the base file name from the source path will be used as the destination name. When copying multiple files, the destination path must refer to an existing remote directory. If no destination path is provided, the file is copied into the current remote working directory. If preserve is ``True``, the access and modification times and permissions of the original file are set on the copied file. If recurse is ``True`` and the source path points at a directory, the entire subtree under that directory is copied. If follow_symlinks is set to ``True``, symbolic links found in the source will have the contents of their target copied rather than creating a copy of the symbolic link. When using this option during a recursive copy, one needs to watch out for links that result in loops. The block_size value controls the size of read and write operations issued to copy the files. It defaults to 16 KB. If progress_handler is specified, it will be called after each block of a file is successfully copied. The arguments passed to this handler will be the source path, destination path, bytes copied so far, and total bytes in the file being copied. If multiple source paths are provided or recurse is set to ``True``, the progress_handler will be called consecutively on each file being copied. If error_handler is specified and an error occurs during the copy, this handler will be called with the exception instead of it being raised. This is intended to primarily be used when multiple source paths are provided or when recurse is set to ``True``, to allow error information to be collected without aborting the copy of the remaining files. The error handler can raise an exception if it wants the copy to completely stop. Otherwise, after an error, the copy will continue starting with the next file. 
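           A minimal sketch (assuming ``sftp`` is an established
           :class:`SFTPClient` and that this code runs inside a
           coroutine; the remote paths shown are hypothetical)::

               errors = []

               yield from sftp.copy('data.txt', 'data.bak')
               yield from sftp.copy(['logs', 'conf'], 'archive',
                                    recurse=True,
                                    error_handler=errors.append)
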
           :param srcpaths:
               The paths of the remote files or directories to copy
           :param dstpath: (optional)
               The path of the remote file or directory to copy into
           :param bool preserve: (optional)
               Whether or not to preserve the original file attributes
           :param bool recurse: (optional)
               Whether or not to recursively copy directories
           :param bool follow_symlinks: (optional)
               Whether or not to follow symbolic links
           :param int block_size: (optional)
               The block size to use for file reads and writes
           :param callable progress_handler: (optional)
               The function to call to report copy progress
           :param callable error_handler: (optional)
               The function to call when an error occurs
           :type srcpaths: str or bytes, or a sequence of these
           :type dstpath: str or bytes

           :raises: | :exc:`OSError` if a local file I/O error occurs
                    | :exc:`SFTPError` if the server returns an error

        """

        yield from self._begin_copy(self, self, srcpaths, dstpath, preserve,
                                    recurse, follow_symlinks, block_size,
                                    progress_handler, error_handler)

    @asyncio.coroutine
    def mget(self, remotepaths, localpath=None, *, preserve=False,
             recurse=False, follow_symlinks=False,
             block_size=SFTP_BLOCK_SIZE, progress_handler=None,
             error_handler=None):
        """Download remote files with glob pattern match

           This method downloads files and directories from the remote
           system matching one or more glob patterns.

           The arguments to this method are identical to the :meth:`get`
           method, except that the remote paths specified can contain
           '*' and '?' wildcard characters.

        """

        matches = yield from self._glob(self, remotepaths, error_handler)

        yield from self._begin_copy(self, LocalFile, matches, localpath,
                                    preserve, recurse, follow_symlinks,
                                    block_size, progress_handler,
                                    error_handler)

    @asyncio.coroutine
    def mput(self, localpaths, remotepath=None, *, preserve=False,
             recurse=False, follow_symlinks=False,
             block_size=SFTP_BLOCK_SIZE, progress_handler=None,
             error_handler=None):
        """Upload local files with glob pattern match

           This method uploads files and directories to the remote
           system matching one or more glob patterns.

           The arguments to this method are identical to the :meth:`put`
           method, except that the local paths specified can contain
           '*' and '?' wildcard characters.

        """

        matches = yield from self._glob(LocalFile, localpaths, error_handler)

        yield from self._begin_copy(LocalFile, self, matches, remotepath,
                                    preserve, recurse, follow_symlinks,
                                    block_size, progress_handler,
                                    error_handler)

    @asyncio.coroutine
    def mcopy(self, srcpaths, dstpath=None, *, preserve=False,
              recurse=False, follow_symlinks=False,
              block_size=SFTP_BLOCK_SIZE, progress_handler=None,
              error_handler=None):
        """Copy remote files with glob pattern match

           This method copies files and directories on the remote
           system matching one or more glob patterns.

           The arguments to this method are identical to the :meth:`copy`
           method, except that the source paths specified can contain
           '*' and '?' wildcard characters.

        """

        matches = yield from self._glob(self, srcpaths, error_handler)

        yield from self._begin_copy(self, self, matches, dstpath, preserve,
                                    recurse, follow_symlinks, block_size,
                                    progress_handler, error_handler)

    @asyncio.coroutine
    def glob(self, patterns, error_handler=None):
        """Match remote files against glob patterns

           This method matches remote files against one or more glob
           patterns. Either a single pattern or a sequence of patterns
           can be provided to match against.

           If error_handler is specified and an error occurs during the
           match, this handler will be called with the exception instead
           of it being raised.
This is intended to primarily be used when multiple patterns are provided to allow error information to be collected without aborting the match against the remaining patterns. The error handler can raise an exception if it wants to completely abort the match. Otherwise, after an error, the match will continue starting with the next pattern. An error will be raised if any of the patterns completely fail to match, and this can either stop the match against the remaining patterns or be handled by the error_handler just like other errors. :param patterns: Glob patterns to try and match remote files against :param callable error_handler: (optional) The function to call when an error occurs :type patterns: str or bytes, or a sequence of these :raises: :exc:`SFTPError` if the server returns an error or no match is found """ return (yield from self._glob(self, patterns, error_handler)) @async_context_manager def open(self, path, pflags_or_mode=FXF_READ, attrs=SFTPAttrs(), encoding='utf-8', errors='strict'): """Open a remote file This method opens a remote file and returns an :class:`SFTPClientFile` object which can be used to read and write data and get and set file attributes. The path can be either a str or bytes value. If it is a str, it will be encoded using the file encoding specified when the :class:`SFTPClient` was started. The following open mode flags are supported: ========== ====================================================== Mode Description ========== ====================================================== FXF_READ Open the file for reading. FXF_WRITE Open the file for writing. If both this and FXF_READ are set, open the file for both reading and writing. FXF_APPEND Force writes to append data to the end of the file regardless of seek position. FXF_CREAT Create the file if it doesn't exist. Without this, attempts to open a non-existent file will fail. FXF_TRUNC Truncate the file to zero length if it already exists. FXF_EXCL Return an error when trying to open a file which already exists. ========== ====================================================== By default, file data is read and written as strings in UTF-8 format with strict error checking, but this can be changed using the ``encoding`` and ``errors`` parameters. To read and write data as bytes in binary format, an ``encoding`` value of ``None`` can be used. Instead of these flags, a Python open mode string can also be provided. Python open modes map to the above flags as follows: ==== ============================================= Mode Flags ==== ============================================= r FXF_READ w FXF_WRITE | FXF_CREAT | FXF_TRUNC a FXF_WRITE | FXF_CREAT | FXF_APPEND x FXF_WRITE | FXF_CREAT | FXF_EXCL r+ FXF_READ | FXF_WRITE w+ FXF_READ | FXF_WRITE | FXF_CREAT | FXF_TRUNC a+ FXF_READ | FXF_WRITE | FXF_CREAT | FXF_APPEND x+ FXF_READ | FXF_WRITE | FXF_CREAT | FXF_EXCL ==== ============================================= Including a 'b' in the mode causes the ``encoding`` to be set to ``None``, forcing all data to be read and written as bytes in binary format. The attrs argument is used to set initial attributes of the file if it needs to be created. Otherwise, this argument is ignored. 
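           A minimal sketch of the mode strings (illustrative only,
           assuming ``sftp`` is an established :class:`SFTPClient` and
           that this code runs inside a coroutine)::

               f = yield from sftp.open('notes.txt', 'a')   # append, text
               yield from f.write('one more line')
               yield from f.close()

               f = yield from sftp.open('image.bin', 'rb')  # binary, bytes
               data = yield from f.read()
               yield from f.close()
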
:param path: The name of the remote file to open :param pflags_or_mode: (optional) The access mode to use for the remote file (see above) :param attrs: (optional) File attributes to use if the file needs to be created :param str encoding: (optional) The Unicode encoding to use for data read and written to the remote file :param str errors: (optional) The error-handling mode if an invalid Unicode byte sequence is detected, defaulting to 'strict' which raises an exception :type path: str or bytes :type pflags_or_mode: int or str :type attrs: :class:`SFTPAttrs` :returns: An :class:`SFTPClientFile` to use to access the file :raises: | :exc:`ValueError` if the mode is not valid | :exc:`SFTPError` if the server returns an error """ if isinstance(pflags_or_mode, str): pflags, binary = _mode_to_pflags(pflags_or_mode) if binary: encoding = None else: pflags = pflags_or_mode path = self.compose_path(path) handle = yield from self._handler.open(path, pflags, attrs) return SFTPClientFile(self._handler, handle, pflags & FXF_APPEND, encoding, errors) @asyncio.coroutine def stat(self, path): """Get attributes of a remote file or directory, following symlinks This method queries the attributes of a remote file or directory. If the path provided is a symbolic link, the returned attributes will correspond to the target of the link. :param path: The path of the remote file or directory to get attributes for :type path: str or bytes :returns: An :class:`SFTPAttrs` containing the file attributes :raises: :exc:`SFTPError` if the server returns an error """ path = self.compose_path(path) return (yield from self._handler.stat(path)) @asyncio.coroutine def lstat(self, path): """Get attributes of a remote file, directory, or symlink This method queries the attributes of a remote file, directory, or symlink. Unlike :meth:`stat`, this method returns the attributes of a symlink itself rather than the target of that link. :param path: The path of the remote file, directory, or link to get attributes for :type path: str or bytes :returns: An :class:`SFTPAttrs` containing the file attributes :raises: :exc:`SFTPError` if the server returns an error """ path = self.compose_path(path) return (yield from self._handler.lstat(path)) @asyncio.coroutine def setstat(self, path, attrs): """Set attributes of a remote file or directory This method sets attributes of a remote file or directory. If the path provided is a symbolic link, the attributes will be set on the target of the link. A subset of the fields in ``attrs`` can be initialized and only those attributes will be changed. :param path: The path of the remote file or directory to set attributes for :param attrs: File attributes to set :type path: str or bytes :type attrs: :class:`SFTPAttrs` :raises: :exc:`SFTPError` if the server returns an error """ path = self.compose_path(path) yield from self._handler.setstat(path, attrs) @asyncio.coroutine def statvfs(self, path): """Get attributes of a remote file system This method queries the attributes of the file system containing the specified path. 
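           For example (an illustrative sketch, assuming ``sftp`` is an
           established :class:`SFTPClient` and that this code runs inside
           a coroutine), the free space on the file system holding the
           current remote directory could be estimated from the returned
           fields::

               vfs = yield from sftp.statvfs('.')
               free_bytes = vfs.bfree * vfs.frsize
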
:param path: The path of the remote file system to get attributes for :type path: str or bytes :returns: An :class:`SFTPVFSAttrs` containing the file system attributes :raises: :exc:`SFTPError` if the server doesn't support this extension or returns an error """ path = self.compose_path(path) return (yield from self._handler.statvfs(path)) @asyncio.coroutine def truncate(self, path, size): """Truncate a remote file to the specified size This method truncates a remote file to the specified size. If the path provided is a symbolic link, the target of the link will be truncated. :param path: The path of the remote file to be truncated :param int size: The desired size of the file, in bytes :type path: str or bytes :raises: :exc:`SFTPError` if the server returns an error """ yield from self.setstat(path, SFTPAttrs(size=size)) @asyncio.coroutine def chown(self, path, uid, gid): """Change the owner user and group id of a remote file or directory This method changes the user and group id of a remote file or directory. If the path provided is a symbolic link, the target of the link will be changed. :param path: The path of the remote file to change :param int uid: The new user id to assign to the file :param int gid: The new group id to assign to the file :type path: str or bytes :raises: :exc:`SFTPError` if the server returns an error """ yield from self.setstat(path, SFTPAttrs(uid=uid, gid=gid)) @asyncio.coroutine def chmod(self, path, mode): """Change the file permissions of a remote file or directory This method changes the permissions of a remote file or directory. If the path provided is a symbolic link, the target of the link will be changed. :param path: The path of the remote file to change :param int mode: The new file permissions, expressed as an int :type path: str or bytes :raises: :exc:`SFTPError` if the server returns an error """ yield from self.setstat(path, SFTPAttrs(permissions=mode)) @asyncio.coroutine def utime(self, path, times=None): """Change the access and modify times of a remote file or directory This method changes the access and modify times of a remote file or directory. If ``times`` is not provided, the times will be changed to the current time. If the path provided is a symbolic link, the target of the link will be changed. 
:param path: The path of the remote file to change :param times: (optional) The new access and modify times, as seconds relative to the UNIX epoch :type path: str or bytes :type times: tuple of two int or float values :raises: :exc:`SFTPError` if the server returns an error """ # pylint: disable=unpacking-non-sequence if times is None: atime = mtime = time.time() else: atime, mtime = times yield from self.setstat(path, SFTPAttrs(atime=atime, mtime=mtime)) @asyncio.coroutine def exists(self, path): """Return if the remote path exists and isn't a broken symbolic link :param path: The remote path to check :type path: str or bytes :raises: :exc:`SFTPError` if the server returns an error """ return bool((yield from self._mode(path))) @asyncio.coroutine def lexists(self, path): """Return if the remote path exists, without following symbolic links :param path: The remote path to check :type path: str or bytes :raises: :exc:`SFTPError` if the server returns an error """ return bool((yield from self._mode(path, statfunc=self.lstat))) @asyncio.coroutine def getatime(self, path): """Return the last access time of a remote file or directory :param path: The remote path to check :type path: str or bytes :raises: :exc:`SFTPError` if the server returns an error """ return (yield from self.stat(path)).atime @asyncio.coroutine def getmtime(self, path): """Return the last modification time of a remote file or directory :param path: The remote path to check :type path: str or bytes :raises: :exc:`SFTPError` if the server returns an error """ return (yield from self.stat(path)).mtime @asyncio.coroutine def getsize(self, path): """Return the size of a remote file or directory :param path: The remote path to check :type path: str or bytes :raises: :exc:`SFTPError` if the server returns an error """ return (yield from self.stat(path)).size @asyncio.coroutine def isdir(self, path): """Return if the remote path refers to a directory :param path: The remote path to check :type path: str or bytes :raises: :exc:`SFTPError` if the server returns an error """ return stat.S_ISDIR((yield from self._mode(path))) @asyncio.coroutine def isfile(self, path): """Return if the remote path refers to a regular file :param path: The remote path to check :type path: str or bytes :raises: :exc:`SFTPError` if the server returns an error """ return stat.S_ISREG((yield from self._mode(path))) @asyncio.coroutine def islink(self, path): """Return if the remote path refers to a symbolic link :param path: The remote path to check :type path: str or bytes :raises: :exc:`SFTPError` if the server returns an error """ return stat.S_ISLNK((yield from self._mode(path, statfunc=self.lstat))) @asyncio.coroutine def remove(self, path): """Remove a remote file This method removes a remote file or symbolic link. :param path: The path of the remote file or link to remove :type path: str or bytes :raises: :exc:`SFTPError` if the server returns an error """ path = self.compose_path(path) yield from self._handler.remove(path) @asyncio.coroutine def unlink(self, path): """Remove a remote file (see :meth:`remove`)""" yield from self.remove(path) @asyncio.coroutine def rename(self, oldpath, newpath): """Rename a remote file, directory, or link This method renames a remote file, directory, or link. .. note:: This requests the standard SFTP version of rename which will not overwrite the new path if it already exists. To request POSIX behavior where the new path is removed before the rename, use :meth:`posix_rename`. 
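           For instance (illustrative only, assuming ``sftp`` is an
           established :class:`SFTPClient` and that this code runs inside
           a coroutine), the two calls below behave differently when
           ``new.txt`` already exists::

               # Raises SFTPError if 'new.txt' already exists
               yield from sftp.rename('old.txt', 'new.txt')

               # Replaces 'new.txt' if the server supports the
               # posix-rename@openssh.com extension
               yield from sftp.posix_rename('old.txt', 'new.txt')
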
:param oldpath: The path of the remote file, directory, or link to rename :param newpath: The new name for this file, directory, or link :type oldpath: str or bytes :type newpath: str or bytes :raises: :exc:`SFTPError` if the server returns an error """ oldpath = self.compose_path(oldpath) newpath = self.compose_path(newpath) yield from self._handler.rename(oldpath, newpath) @asyncio.coroutine def posix_rename(self, oldpath, newpath): """Rename a remote file, directory, or link with POSIX semantics This method renames a remote file, directory, or link, removing the prior instance of new path if it previously existed. This method may not be supported by all SFTP servers. :param oldpath: The path of the remote file, directory, or link to rename :param newpath: The new name for this file, directory, or link :type oldpath: str or bytes :type newpath: str or bytes :raises: :exc:`SFTPError` if the server doesn't support this extension or returns an error """ oldpath = self.compose_path(oldpath) newpath = self.compose_path(newpath) yield from self._handler.posix_rename(oldpath, newpath) @asyncio.coroutine def readdir(self, path='.'): """Read the contents of a remote directory This method reads the contents of a directory, returning the names and attributes of what is contained there. If no path is provided, it defaults to the current remote working directory. :param path: (optional) The path of the remote directory to read :type path: str or bytes :returns: A list of :class:`SFTPName` entries, with path names matching the type used to pass in the path :raises: :exc:`SFTPError` if the server returns an error """ names = [] dirpath = self.compose_path(path) handle = yield from self._handler.opendir(dirpath) try: while True: names.extend((yield from self._handler.readdir(handle))) except SFTPError as exc: if exc.code != FX_EOF: raise finally: yield from self._handler.close(handle) if isinstance(path, str): for name in names: name.filename = self.decode(name.filename) name.longname = self.decode(name.longname) return names @asyncio.coroutine def listdir(self, path='.'): """Read the names of the files in a remote directory This method reads the names of files and subdirectories in a remote directory. If no path is provided, it defaults to the current remote working directory. :param path: (optional) The path of the remote directory to read :type path: str or bytes :returns: A list of file/subdirectory names, matching the type used to pass in the path :raises: :exc:`SFTPError` if the server returns an error """ names = yield from self.readdir(path) return [name.filename for name in names] @asyncio.coroutine def mkdir(self, path, attrs=SFTPAttrs()): """Create a remote directory with the specified attributes This method creates a new remote directory at the specified path with the requested attributes. :param path: The path of where the new remote directory should be created :param attrs: (optional) The file attributes to use when creating the directory :type path: str or bytes :type attrs: :class:`SFTPAttrs` :raises: :exc:`SFTPError` if the server returns an error """ path = self.compose_path(path) yield from self._handler.mkdir(path, attrs) @asyncio.coroutine def rmdir(self, path): """Remove a remote directory This method removes a remote directory. The directory must be empty for the removal to succeed. 
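           For instance (illustrative only, assuming ``sftp`` is an
           established :class:`SFTPClient` and that this code runs inside
           a coroutine), a scratch directory created earlier in the
           session could be removed once it is empty::

               yield from sftp.mkdir('scratch')
               # ... work in the directory, then clean it out ...
               yield from sftp.rmdir('scratch')
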
:param path: The path of the remote directory to remove :type path: str or bytes :raises: :exc:`SFTPError` if the server returns an error """ path = self.compose_path(path) yield from self._handler.rmdir(path) @asyncio.coroutine def realpath(self, path): """Return the canonical version of a remote path This method returns a canonical version of the requested path. :param path: (optional) The path of the remote directory to canonicalize :type path: str or bytes :returns: The canonical path as a str or bytes, matching the type used to pass in the path :raises: :exc:`SFTPError` if the server returns an error """ fullpath = self.compose_path(path) names = yield from self._handler.realpath(fullpath) if len(names) > 1: raise SFTPError(FX_BAD_MESSAGE, 'Too many names returned') return self.decode(names[0].filename, isinstance(path, str)) @asyncio.coroutine def getcwd(self): """Return the current remote working directory :returns: The current remote working directory, decoded using the specified path encoding :raises: :exc:`SFTPError` if the server returns an error """ if self._cwd is None: self._cwd = yield from self.realpath(b'.') return self.decode(self._cwd) @asyncio.coroutine def chdir(self, path): """Change the current remote working directory :param path: The path to set as the new remote working directory :type path: str or bytes :raises: :exc:`SFTPError` if the server returns an error """ self._cwd = yield from self.realpath(self.encode(path)) @asyncio.coroutine def readlink(self, path): """Return the target of a remote symbolic link This method returns the target of a symbolic link. :param path: The path of the remote symbolic link to follow :type path: str or bytes :returns: The target path of the link as a str or bytes :raises: :exc:`SFTPError` if the server returns an error """ linkpath = self.compose_path(path) names = yield from self._handler.readlink(linkpath) if len(names) > 1: raise SFTPError(FX_BAD_MESSAGE, 'Too many names returned') return self.decode(names[0].filename, isinstance(path, str)) @asyncio.coroutine def symlink(self, oldpath, newpath): """Create a remote symbolic link This method creates a symbolic link. The argument order here matches the standard Python :meth:`os.symlink` call. The argument order sent on the wire is automatically adapted depending on the version information sent by the server, as a number of servers (OpenSSH in particular) did not follow the SFTP standard when implementing this call. :param oldpath: The path the link should point to :param newpath: The path of where to create the remote symbolic link :type oldpath: str or bytes :type newpath: str or bytes :raises: :exc:`SFTPError` if the server returns an error """ oldpath = self.compose_path(oldpath) newpath = self.encode(newpath) yield from self._handler.symlink(oldpath, newpath) @asyncio.coroutine def link(self, oldpath, newpath): """Create a remote hard link This method creates a hard link to the remote file specified by oldpath at the location specified by newpath. This method may not be supported by all SFTP servers. 
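
# Usage sketch (illustrative): remote working directory and link helpers.
# 'sftp' is assumed to be an open SFTPClient; the path names are
# placeholders for the example.

import asyncio


@asyncio.coroutine
def link_latest(sftp):
    yield from sftp.chdir('logs')
    print('now in', (yield from sftp.getcwd()))

    # Argument order matches os.symlink(): target first, link name second
    yield from sftp.symlink('app-2017-12.log', 'latest.log')

    print('latest ->', (yield from sftp.readlink('latest.log')))
    print('canonical:', (yield from sftp.realpath('latest.log')))
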
:param oldpath: The path of the remote file the hard link should point to :param newpath: The path of where to create the remote hard link :type oldpath: str or bytes :type newpath: str or bytes :raises: :exc:`SFTPError` if the server doesn't support this extension or returns an error """ oldpath = self.compose_path(oldpath) newpath = self.compose_path(newpath) yield from self._handler.link(oldpath, newpath) def exit(self): """Exit the SFTP client session This method exits the SFTP client session, closing the corresponding channel opened on the server. """ self._handler.exit() @asyncio.coroutine def wait_closed(self): """Wait for this SFTP client session to close""" yield from self._handler.wait_closed() class SFTPServerHandler(SFTPHandler): """An SFTP server session handler""" _extensions = [(b'posix-rename@openssh.com', b'1'), (b'hardlink@openssh.com', b'1'), (b'fsync@openssh.com', b'1')] if hasattr(os, 'statvfs'): # pragma: no branch _extensions += [(b'statvfs@openssh.com', b'2'), (b'fstatvfs@openssh.com', b'2')] def __init__(self, server, reader, writer): super().__init__(reader, writer) self._server = server self._version = None self._nonstandard_symlink = False self._next_handle = 0 self._file_handles = {} self._dir_handles = {} @asyncio.coroutine def _cleanup(self, exc): """Clean up this SFTP server session""" if self._server: # pragma: no branch for file_obj in self._file_handles.values(): result = self._server.close(file_obj) if asyncio.iscoroutine(result): result = yield from result result = self._server.exit() if asyncio.iscoroutine(result): result = yield from result self._server = None self._file_handles = [] self._dir_handles = [] yield from super()._cleanup(exc) def _get_next_handle(self): """Get the next available unique file handle number""" while True: handle = self._next_handle.to_bytes(4, 'big') self._next_handle = (self._next_handle + 1) & 0xffffffff if (handle not in self._file_handles and handle not in self._dir_handles): return handle @asyncio.coroutine def _process_packet(self, pkttype, pktid, packet): """Process incoming SFTP requests""" # pylint: disable=broad-except try: if pkttype == FXP_EXTENDED: pkttype = packet.get_string() handler = self._packet_handlers.get(pkttype) if not handler: raise SFTPError(FX_OP_UNSUPPORTED, 'Unsupported request type: %s' % pkttype) return_type = self._return_types.get(pkttype, FXP_STATUS) result = yield from handler(self, packet) if return_type == FXP_STATUS: result = UInt32(FX_OK) + String('') + String('') elif return_type in (FXP_HANDLE, FXP_DATA): result = String(result) elif return_type == FXP_NAME: result = (UInt32(len(result)) + b''.join(name.encode() for name in result)) else: if isinstance(result, os.stat_result): result = SFTPAttrs.from_local(result) elif isinstance(result, os.statvfs_result): result = SFTPVFSAttrs.from_local(result) result = result.encode() except PacketDecodeError as exc: return_type = FXP_STATUS result = (UInt32(FX_BAD_MESSAGE) + String(str(exc)) + String(DEFAULT_LANG)) except SFTPError as exc: return_type = FXP_STATUS result = UInt32(exc.code) + String(exc.reason) + String(exc.lang) except NotImplementedError as exc: name = handler.__name__[9:] return_type = FXP_STATUS result = (UInt32(FX_OP_UNSUPPORTED) + String('Operation not supported: %s' % name) + String(DEFAULT_LANG)) except OSError as exc: return_type = FXP_STATUS if exc.errno in (errno.ENOENT, errno.ENOTDIR): code = FX_NO_SUCH_FILE elif exc.errno == errno.EACCES: code = FX_PERMISSION_DENIED else: code = FX_FAILURE result = (UInt32(code) + 
String(exc.strerror or str(exc)) + String(DEFAULT_LANG)) except Exception as exc: # pragma: no cover return_type = FXP_STATUS result = (UInt32(FX_FAILURE) + String('Uncaught exception: %s' % str(exc)) + String(DEFAULT_LANG)) self.send_packet(Byte(return_type), UInt32(pktid), result) @asyncio.coroutine def _process_open(self, packet): """Process an incoming SFTP open request""" path = packet.get_string() pflags = packet.get_uint32() attrs = SFTPAttrs.decode(packet) packet.check_end() result = self._server.open(path, pflags, attrs) if asyncio.iscoroutine(result): result = yield from result handle = self._get_next_handle() self._file_handles[handle] = result return handle @asyncio.coroutine def _process_close(self, packet): """Process an incoming SFTP close request""" handle = packet.get_string() packet.check_end() file_obj = self._file_handles.pop(handle, None) if file_obj: result = self._server.close(file_obj) if asyncio.iscoroutine(result): yield from result return if self._dir_handles.pop(handle, None) is not None: return raise SFTPError(FX_FAILURE, 'Invalid file handle') @asyncio.coroutine def _process_read(self, packet): """Process an incoming SFTP read request""" handle = packet.get_string() offset = packet.get_uint64() length = packet.get_uint32() packet.check_end() file_obj = self._file_handles.get(handle) if file_obj: result = self._server.read(file_obj, offset, length) if asyncio.iscoroutine(result): result = yield from result if result: return result else: raise SFTPError(FX_EOF, '') else: raise SFTPError(FX_FAILURE, 'Invalid file handle') @asyncio.coroutine def _process_write(self, packet): """Process an incoming SFTP write request""" handle = packet.get_string() offset = packet.get_uint64() data = packet.get_string() packet.check_end() file_obj = self._file_handles.get(handle) if file_obj: result = self._server.write(file_obj, offset, data) if asyncio.iscoroutine(result): result = yield from result return result else: raise SFTPError(FX_FAILURE, 'Invalid file handle') @asyncio.coroutine def _process_lstat(self, packet): """Process an incoming SFTP lstat request""" path = packet.get_string() packet.check_end() result = self._server.lstat(path) if asyncio.iscoroutine(result): result = yield from result return result @asyncio.coroutine def _process_fstat(self, packet): """Process an incoming SFTP fstat request""" handle = packet.get_string() packet.check_end() file_obj = self._file_handles.get(handle) if file_obj: result = self._server.fstat(file_obj) if asyncio.iscoroutine(result): result = yield from result return result else: raise SFTPError(FX_FAILURE, 'Invalid file handle') @asyncio.coroutine def _process_setstat(self, packet): """Process an incoming SFTP setstat request""" path = packet.get_string() attrs = SFTPAttrs.decode(packet) packet.check_end() result = self._server.setstat(path, attrs) if asyncio.iscoroutine(result): result = yield from result return result @asyncio.coroutine def _process_fsetstat(self, packet): """Process an incoming SFTP fsetstat request""" handle = packet.get_string() attrs = SFTPAttrs.decode(packet) packet.check_end() file_obj = self._file_handles.get(handle) if file_obj: result = self._server.fsetstat(file_obj, attrs) if asyncio.iscoroutine(result): result = yield from result return result else: raise SFTPError(FX_FAILURE, 'Invalid file handle') @asyncio.coroutine def _process_opendir(self, packet): """Process an incoming SFTP opendir request""" path = packet.get_string() packet.check_end() listdir_result = self._server.listdir(path) if 
asyncio.iscoroutine(listdir_result): listdir_result = yield from listdir_result for i, name in enumerate(listdir_result): # pylint: disable=no-member if isinstance(name, bytes): name = SFTPName(name) listdir_result[i] = name # pylint: disable=attribute-defined-outside-init filename = os.path.join(path, name.filename) attr_result = self._server.lstat(filename) if asyncio.iscoroutine(attr_result): attr_result = yield from attr_result if isinstance(attr_result, os.stat_result): attr_result = SFTPAttrs.from_local(attr_result) name.attrs = attr_result if not name.longname: longname_result = self._server.format_longname(name) if asyncio.iscoroutine(longname_result): yield from longname_result handle = self._get_next_handle() self._dir_handles[handle] = listdir_result return handle @asyncio.coroutine def _process_readdir(self, packet): """Process an incoming SFTP readdir request""" handle = packet.get_string() packet.check_end() names = self._dir_handles.get(handle) if names: result = names[:_MAX_READDIR_NAMES] del names[:_MAX_READDIR_NAMES] return result else: raise SFTPError(FX_EOF, '') @asyncio.coroutine def _process_remove(self, packet): """Process an incoming SFTP remove request""" path = packet.get_string() packet.check_end() result = self._server.remove(path) if asyncio.iscoroutine(result): result = yield from result return result @asyncio.coroutine def _process_mkdir(self, packet): """Process an incoming SFTP mkdir request""" path = packet.get_string() attrs = SFTPAttrs.decode(packet) packet.check_end() result = self._server.mkdir(path, attrs) if asyncio.iscoroutine(result): result = yield from result return result @asyncio.coroutine def _process_rmdir(self, packet): """Process an incoming SFTP rmdir request""" path = packet.get_string() packet.check_end() result = self._server.rmdir(path) if asyncio.iscoroutine(result): result = yield from result return result @asyncio.coroutine def _process_realpath(self, packet): """Process an incoming SFTP realpath request""" path = packet.get_string() packet.check_end() result = self._server.realpath(path) if asyncio.iscoroutine(result): result = yield from result return [SFTPName(result)] @asyncio.coroutine def _process_stat(self, packet): """Process an incoming SFTP stat request""" path = packet.get_string() packet.check_end() result = self._server.stat(path) if asyncio.iscoroutine(result): result = yield from result return result @asyncio.coroutine def _process_rename(self, packet): """Process an incoming SFTP rename request""" oldpath = packet.get_string() newpath = packet.get_string() packet.check_end() result = self._server.rename(oldpath, newpath) if asyncio.iscoroutine(result): result = yield from result return result @asyncio.coroutine def _process_readlink(self, packet): """Process an incoming SFTP readlink request""" path = packet.get_string() packet.check_end() result = self._server.readlink(path) if asyncio.iscoroutine(result): result = yield from result return [SFTPName(result)] @asyncio.coroutine def _process_symlink(self, packet): """Process an incoming SFTP symlink request""" if self._nonstandard_symlink: oldpath = packet.get_string() newpath = packet.get_string() else: newpath = packet.get_string() oldpath = packet.get_string() packet.check_end() result = self._server.symlink(oldpath, newpath) if asyncio.iscoroutine(result): result = yield from result return result @asyncio.coroutine def _process_posix_rename(self, packet): """Process an incoming SFTP POSIX rename request""" oldpath = packet.get_string() newpath = 
packet.get_string() packet.check_end() result = self._server.posix_rename(oldpath, newpath) if asyncio.iscoroutine(result): result = yield from result return result @asyncio.coroutine def _process_statvfs(self, packet): """Process an incoming SFTP statvfs request""" path = packet.get_string() packet.check_end() result = self._server.statvfs(path) if asyncio.iscoroutine(result): result = yield from result return result @asyncio.coroutine def _process_fstatvfs(self, packet): """Process an incoming SFTP fstatvfs request""" handle = packet.get_string() packet.check_end() file_obj = self._file_handles.get(handle) if file_obj: result = self._server.fstatvfs(file_obj) if asyncio.iscoroutine(result): result = yield from result return result else: raise SFTPError(FX_FAILURE, 'Invalid file handle') @asyncio.coroutine def _process_link(self, packet): """Process an incoming SFTP hard link request""" oldpath = packet.get_string() newpath = packet.get_string() packet.check_end() result = self._server.link(oldpath, newpath) if asyncio.iscoroutine(result): result = yield from result return result @asyncio.coroutine def _process_fsync(self, packet): """Process an incoming SFTP fsync request""" handle = packet.get_string() packet.check_end() file_obj = self._file_handles.get(handle) if file_obj: result = self._server.fsync(file_obj) if asyncio.iscoroutine(result): result = yield from result return result else: raise SFTPError(FX_FAILURE, 'Invalid file handle') _packet_handlers = { FXP_OPEN: _process_open, FXP_CLOSE: _process_close, FXP_READ: _process_read, FXP_WRITE: _process_write, FXP_LSTAT: _process_lstat, FXP_FSTAT: _process_fstat, FXP_SETSTAT: _process_setstat, FXP_FSETSTAT: _process_fsetstat, FXP_OPENDIR: _process_opendir, FXP_READDIR: _process_readdir, FXP_REMOVE: _process_remove, FXP_MKDIR: _process_mkdir, FXP_RMDIR: _process_rmdir, FXP_REALPATH: _process_realpath, FXP_STAT: _process_stat, FXP_RENAME: _process_rename, FXP_READLINK: _process_readlink, FXP_SYMLINK: _process_symlink, b'posix-rename@openssh.com': _process_posix_rename, b'statvfs@openssh.com': _process_statvfs, b'fstatvfs@openssh.com': _process_fstatvfs, b'hardlink@openssh.com': _process_link, b'fsync@openssh.com': _process_fsync } @asyncio.coroutine def run(self): """Run an SFTP server""" try: packet = yield from self.recv_packet() pkttype = packet.get_byte() version = packet.get_uint32() except PacketDecodeError as exc: yield from self._cleanup(SFTPError(FX_BAD_MESSAGE, str(exc))) return except SFTPError as exc: yield from self._cleanup(exc) return if pkttype != FXP_INIT: yield from self._cleanup(SFTPError(FX_BAD_MESSAGE, 'Expected init message')) return version = min(version, _SFTP_VERSION) extensions = (String(name) + String(data) for name, data in self._extensions) try: self.send_packet(Byte(FXP_VERSION), UInt32(version), *extensions) except SFTPError as exc: yield from self._cleanup(exc) return if version == 3: # Check if the server has a buggy SYMLINK implementation client_version = self._reader.get_extra_info('client_version', '') if any(name in client_version for name in self._nonstandard_symlink_impls): self._nonstandard_symlink = True yield from self.recv_packets() class SFTPServer: """SFTP server Applications should subclass this when implementing an SFTP server. The methods listed below should be implemented to provide the desired application behavior. .. note:: Any method can optionally be defined as a coroutine if that method needs to perform blocking opertions to determine its result. 
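
# Sketch (illustrative): as the note above says, any SFTPServer method may
# be defined as a coroutine when it needs to wait on something.  Here
# remove() is made a coroutine that consults a stand-in asynchronous audit
# hook (audit_delete is not a real API) before calling the default
# implementation.

import asyncio
import asyncssh


@asyncio.coroutine
def audit_delete(path):
    """Stand-in for an asynchronous audit hook (illustrative only)"""
    yield from asyncio.sleep(0)


class AuditedSFTPServer(asyncssh.SFTPServer):
    @asyncio.coroutine
    def remove(self, path):
        yield from audit_delete(path)
        return super().remove(path)
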
The ``conn`` object provided here refers to the :class:`SSHServerConnection` instance this SFTP server is associated with. It can be queried to determine which user the client authenticated as or to request key and certificate options or permissions which should be applied to this session. If the ``chroot`` argument is specified when this object is created, the default :meth:`map_path` and :meth:`reverse_map_path` methods will enforce a virtual root directory starting in that location, limiting access to only files within that directory tree. This will also affect path names returned by the :meth:`realpath` and :meth:`readlink` methods. """ # The default implementation of a number of these methods don't need self # pylint: disable=no-self-use def __init__(self, conn, chroot=None): # pylint: disable=unused-argument if chroot: self._chroot = _from_local_path(os.path.realpath(chroot)) else: self._chroot = None def format_user(self, uid): """Return the user name associated with a uid This method returns a user name string to insert into the ``longname`` field of an :class:`SFTPName` object. By default, it calls the Python :func:`pwd.getpwuid` function if it is available, or returns the numeric uid as a string if not. If there is no uid, it returns an empty string. :param uid: The uid value to look up :type uid: int or ``None`` :returns: The formatted user name string """ if uid is not None: try: import pwd user = pwd.getpwuid(uid).pw_name except (ImportError, KeyError): user = str(uid) else: user = '' return user def format_group(self, gid): """Return the group name associated with a gid This method returns a group name string to insert into the ``longname`` field of an :class:`SFTPName` object. By default, it calls the Python :func:`grp.getgrgid` function if it is available, or returns the numeric gid as a string if not. If there is no gid, it returns an empty string. :param gid: The gid value to look up :type gid: int or ``None`` :returns: The formatted group name string """ if gid is not None: try: import grp group = grp.getgrgid(gid).gr_name except (ImportError, KeyError): group = str(gid) else: group = '' return group def format_longname(self, name): """Format the long name associated with an SFTP name This method fills in the ``longname`` field of a :class:`SFTPName` object. By default, it generates something similar to UNIX "ls -l" output. The ``filename`` and ``attrs`` fields of the :class:`SFTPName` should already be filled in before this method is called. 
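
# Sketch (illustrative): giving each authenticated user a private virtual
# root by passing 'chroot' to the base class, as described above.  The
# on-disk location is an assumption for the example; get_extra_info() is
# used to recover the username the client authenticated as.

import os
import asyncssh


class PerUserSFTPServer(asyncssh.SFTPServer):
    def __init__(self, conn):
        root = os.path.join('/srv/sftp', conn.get_extra_info('username'))
        os.makedirs(root, exist_ok=True)
        super().__init__(conn, chroot=root)

# A class like this would typically be installed via the SSH server's
# sftp_factory argument when the listener is created.
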
:param name: The :class:`SFTPName` instance to format the long name for :type name: :class:`SFTPName` """ if name.attrs.permissions is not None: mode = stat.filemode(name.attrs.permissions) else: mode = '' nlink = str(name.attrs.nlink) if name.attrs.nlink else '' user = self.format_user(name.attrs.uid) group = self.format_group(name.attrs.gid) size = str(name.attrs.size) if name.attrs.size is not None else '' if name.attrs.mtime is not None: now = time.time() mtime = time.localtime(name.attrs.mtime) modtime = time.strftime('%b ', mtime) try: modtime += time.strftime('%e', mtime) except ValueError: modtime += time.strftime('%d', mtime) if now - 365*24*60*60/2 < name.attrs.mtime <= now: modtime += time.strftime(' %H:%M', mtime) else: modtime += time.strftime(' %Y', mtime) else: modtime = '' detail = '{:10s} {:>4s} {:8s} {:8s} {:>8s} {:12s} '.format( mode, nlink, user, group, size, modtime) name.longname = detail.encode('utf-8') + name.filename def map_path(self, path): """Map the path requested by the client to a local path This method can be overridden to provide a custom mapping from path names requested by the client to paths in the local filesystem. By default, it will enforce a virtual "chroot" if one was specified when this server was created. Otherwise, path names are left unchanged, with relative paths being interpreted based on the working directory of the currently running process. :param bytes path: The path name to map :returns: bytes containing the local path name to operate on """ if self._chroot: normpath = posixpath.normpath(os.path.join(b'/', path)) return posixpath.join(self._chroot, normpath[1:]) else: return path def reverse_map_path(self, path): """Reverse map a local path into the path reported to the client This method can be overridden to provide a custom reverse mapping for the mapping provided by :meth:`map_path`. By default, it hides the portion of the local path associated with the virtual "chroot" if one was specified. :param bytes path: The local path name to reverse map :returns: bytes containing the path name to report to the client """ if self._chroot: if path == self._chroot: return b'/' elif path.startswith(self._chroot + b'/'): return path[len(self._chroot):] else: raise SFTPError(FX_NO_SUCH_FILE, 'File not found') else: return path def open(self, path, pflags, attrs): """Open a file to serve to a remote client This method returns a file object which can be used to read and write data and get and set file attributes. The possible open mode flags and their meanings are: ========== ====================================================== Mode Description ========== ====================================================== FXF_READ Open the file for reading. If neither FXF_READ nor FXF_WRITE are set, this is the default. FXF_WRITE Open the file for writing. If both this and FXF_READ are set, open the file for both reading and writing. FXF_APPEND Force writes to append data to the end of the file regardless of seek position. FXF_CREAT Create the file if it doesn't exist. Without this, attempts to open a non-existent file will fail. FXF_TRUNC Truncate the file to zero length if it already exists. FXF_EXCL Return an error when trying to open a file which already exists. ========== ====================================================== The attrs argument is used to set initial attributes of the file if it needs to be created. Otherwise, this argument is ignored. 
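
# Sketch (illustrative): a custom path mapping along the lines described
# for map_path() and reverse_map_path() above.  Instead of a chroot, every
# client path is re-rooted under a fixed bytes prefix; the prefix value is
# an assumption, and FX_NO_SUCH_FILE is assumed to be importable from the
# asyncssh namespace as it is used elsewhere in this module.

import posixpath
import asyncssh

_PREFIX = b'/srv/shared'


class PrefixSFTPServer(asyncssh.SFTPServer):
    def map_path(self, path):
        # Client paths arrive as bytes; normalize and re-root them
        normpath = posixpath.normpath(posixpath.join(b'/', path))
        return posixpath.join(_PREFIX, normpath[1:])

    def reverse_map_path(self, path):
        if path == _PREFIX:
            return b'/'
        elif path.startswith(_PREFIX + b'/'):
            return path[len(_PREFIX):]
        else:
            raise asyncssh.SFTPError(asyncssh.FX_NO_SUCH_FILE,
                                     'File not found')
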
:param bytes path: The name of the file to open :param int pflags: The access mode to use for the file (see above) :param attrs: File attributes to use if the file needs to be created :type attrs: :class:`SFTPAttrs` :returns: A file object to use to access the file :raises: :exc:`SFTPError` to return an error to the client """ if pflags & FXF_EXCL: mode = 'xb' elif pflags & FXF_APPEND: mode = 'ab' elif pflags & FXF_WRITE and not pflags & FXF_READ: mode = 'wb' else: mode = 'rb' if pflags & FXF_READ and pflags & FXF_WRITE: mode += '+' flags = os.O_RDWR elif pflags & FXF_WRITE: flags = os.O_WRONLY else: flags = os.O_RDONLY if pflags & FXF_APPEND: flags |= os.O_APPEND if pflags & FXF_CREAT: flags |= os.O_CREAT if pflags & FXF_TRUNC: flags |= os.O_TRUNC if pflags & FXF_EXCL: flags |= os.O_EXCL flags |= getattr(os, 'O_BINARY', 0) perms = 0o666 if attrs.permissions is None else attrs.permissions return open(_to_local_path(self.map_path(path)), mode, buffering=0, opener=lambda path, _: os.open(path, flags, perms)) def close(self, file_obj): """Close an open file or directory :param file file_obj: The file or directory object to close :raises: :exc:`SFTPError` to return an error to the client """ file_obj.close() def read(self, file_obj, offset, size): """Read data from an open file :param file file_obj: The file to read from :param int offset: The offset from the beginning of the file to begin reading :param int size: The number of bytes to read :returns: bytes read from the file :raises: :exc:`SFTPError` to return an error to the client """ file_obj.seek(offset) return file_obj.read(size) def write(self, file_obj, offset, data): """Write data to an open file :param file file_obj: The file to write to :param int offset: The offset from the beginning of the file to begin writing :param bytes data: The data to write to the file :returns: number of bytes written :raises: :exc:`SFTPError` to return an error to the client """ file_obj.seek(offset) return file_obj.write(data) def lstat(self, path): """Get attributes of a file, directory, or symlink This method queries the attributes of a file, directory, or symlink. Unlike :meth:`stat`, this method should return the attributes of a symlink itself rather than the target of that link. :param bytes path: The path of the file, directory, or link to get attributes for :returns: An :class:`SFTPAttrs` or an os.stat_result containing the file attributes :raises: :exc:`SFTPError` to return an error to the client """ return os.lstat(_to_local_path(self.map_path(path))) def fstat(self, file_obj): """Get attributes of an open file :param file file_obj: The file to get attributes for :returns: An :class:`SFTPAttrs` or an os.stat_result containing the file attributes :raises: :exc:`SFTPError` to return an error to the client """ file_obj.flush() return os.fstat(file_obj.fileno()) def setstat(self, path, attrs): """Set attributes of a file or directory This method sets attributes of a file or directory. If the path provided is a symbolic link, the attributes should be set on the target of the link. A subset of the fields in ``attrs`` can be initialized and only those attributes should be changed. 
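
# Sketch (illustrative): read() and write() above receive the file object
# returned by open() plus an absolute offset, so a subclass can wrap them
# to record transfer volume.  The logger name is an assumption.

import logging
import asyncssh

transfer_log = logging.getLogger('sftp.transfers')


class MeteredSFTPServer(asyncssh.SFTPServer):
    def read(self, file_obj, offset, size):
        data = super().read(file_obj, offset, size)
        transfer_log.debug('read %d bytes at offset %d', len(data), offset)
        return data

    def write(self, file_obj, offset, data):
        transfer_log.debug('write %d bytes at offset %d', len(data), offset)
        return super().write(file_obj, offset, data)
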
:param bytes path: The path of the remote file or directory to set attributes for :param attrs: File attributes to set :type attrs: :class:`SFTPAttrs` :raises: :exc:`SFTPError` to return an error to the client """ _setstat(_to_local_path(self.map_path(path)), attrs) def fsetstat(self, file_obj, attrs): """Set attributes of an open file :param attrs: File attributes to set on the file :type attrs: :class:`SFTPAttrs` :raises: :exc:`SFTPError` to return an error to the client """ file_obj.flush() if sys.platform == 'win32': # pragma: no cover _setstat(file_obj.name, attrs) else: _setstat(file_obj.fileno(), attrs) def listdir(self, path): """List the contents of a directory :param bytes path: The path of the directory to open :returns: A list of names of files in the directory :raises: :exc:`SFTPError` to return an error to the client """ files = os.listdir(_to_local_path(self.map_path(path))) if sys.platform == 'win32': # pragma: no cover files = [os.fsencode(f) for f in files] return [b'.', b'..'] + files def remove(self, path): """Remove a file or symbolic link :param bytes path: The path of the file or link to remove :raises: :exc:`SFTPError` to return an error to the client """ os.remove(_to_local_path(self.map_path(path))) def mkdir(self, path, attrs): """Create a directory with the specified attributes :param bytes path: The path of where the new directory should be created :param attrs: The file attributes to use when creating the directory :type attrs: :class:`SFTPAttrs` :raises: :exc:`SFTPError` to return an error to the client """ mode = 0o777 if attrs.permissions is None else attrs.permissions os.mkdir(_to_local_path(self.map_path(path)), mode) def rmdir(self, path): """Remove a directory :param bytes path: The path of the directory to remove :raises: :exc:`SFTPError` to return an error to the client """ os.rmdir(_to_local_path(self.map_path(path))) def realpath(self, path): """Return the canonical version of a path :param bytes path: The path of the directory to canonicalize :returns: bytes containing the canonical path :raises: :exc:`SFTPError` to return an error to the client """ path = os.path.realpath(_to_local_path(self.map_path(path))) return self.reverse_map_path(_from_local_path(path)) def stat(self, path): """Get attributes of a file or directory, following symlinks This method queries the attributes of a file or directory. If the path provided is a symbolic link, the returned attributes should correspond to the target of the link. :param bytes path: The path of the remote file or directory to get attributes for :returns: An :class:`SFTPAttrs` or an os.stat_result containing the file attributes :raises: :exc:`SFTPError` to return an error to the client """ return os.stat(_to_local_path(self.map_path(path))) def rename(self, oldpath, newpath): """Rename a file, directory, or link This method renames a file, directory, or link. .. note:: This is a request for the standard SFTP version of rename which will not overwrite the new path if it already exists. The :meth:`posix_rename` method will be called if the client requests the POSIX behavior where an existing instance of the new path is removed before the rename. 
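
# Sketch (illustrative): hiding dot-files from directory listings and
# refusing to delete anything ending in '.lock'.  FX_PERMISSION_DENIED is
# the same error code used elsewhere in this module for permission errors,
# and is assumed to be importable from the asyncssh namespace.

import asyncssh


class GuardedSFTPServer(asyncssh.SFTPServer):
    def listdir(self, path):
        files = super().listdir(path)
        return [name for name in files
                if name in (b'.', b'..') or not name.startswith(b'.')]

    def remove(self, path):
        if path.endswith(b'.lock'):
            raise asyncssh.SFTPError(asyncssh.FX_PERMISSION_DENIED,
                                     'Lock files may not be removed')
        super().remove(path)
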
:param bytes oldpath: The path of the file, directory, or link to rename :param bytes newpath: The new name for this file, directory, or link :raises: :exc:`SFTPError` to return an error to the client """ oldpath = _to_local_path(self.map_path(oldpath)) newpath = _to_local_path(self.map_path(newpath)) if os.path.exists(newpath): raise SFTPError(FX_FAILURE, 'File already exists') os.rename(oldpath, newpath) def readlink(self, path): """Return the target of a symbolic link :param bytes path: The path of the symbolic link to follow :returns: bytes containing the target path of the link :raises: :exc:`SFTPError` to return an error to the client """ path = os.readlink(_to_local_path(self.map_path(path))) return self.reverse_map_path(_from_local_path(path)) def symlink(self, oldpath, newpath): """Create a symbolic link :param bytes oldpath: The path the link should point to :param bytes newpath: The path of where to create the symbolic link :raises: :exc:`SFTPError` to return an error to the client """ if posixpath.isabs(oldpath): oldpath = self.map_path(oldpath) else: newdir = posixpath.dirname(newpath) abspath1 = self.map_path(posixpath.join(newdir, oldpath)) mapped_newdir = self.map_path(newdir) abspath2 = os.path.join(mapped_newdir, oldpath) # Make sure the symlink doesn't point outside the chroot if os.path.realpath(abspath1) != os.path.realpath(abspath2): oldpath = os.path.relpath(abspath1, start=mapped_newdir) newpath = self.map_path(newpath) os.symlink(_to_local_path(oldpath), _to_local_path(newpath)) def posix_rename(self, oldpath, newpath): """Rename a file, directory, or link with POSIX semantics This method renames a file, directory, or link, removing the prior instance of new path if it previously existed. :param bytes oldpath: The path of the file, directory, or link to rename :param bytes newpath: The new name for this file, directory, or link :raises: :exc:`SFTPError` to return an error to the client """ oldpath = _to_local_path(self.map_path(oldpath)) newpath = _to_local_path(self.map_path(newpath)) os.replace(oldpath, newpath) def statvfs(self, path): """Get attributes of the file system containing a file :param bytes path: The path of the file system to get attributes for :returns: An :class:`SFTPVFSAttrs` or an os.statvfs_result containing the file system attributes :raises: :exc:`SFTPError` to return an error to the client """ try: return os.statvfs(_to_local_path(self.map_path(path))) except AttributeError: # pragma: no cover raise SFTPError(FX_OP_UNSUPPORTED, 'statvfs not supported') def fstatvfs(self, file_obj): """Return attributes of the file system containing an open file :param file file_obj: The open file to get file system attributes for :returns: An :class:`SFTPVFSAttrs` or an os.statvfs_result containing the file system attributes :raises: :exc:`SFTPError` to return an error to the client """ try: return os.statvfs(file_obj.fileno()) except AttributeError: # pragma: no cover raise SFTPError(FX_OP_UNSUPPORTED, 'fstatvfs not supported') def link(self, oldpath, newpath): """Create a hard link :param bytes oldpath: The path of the file the hard link should point to :param bytes newpath: The path of where to create the hard link :raises: :exc:`SFTPError` to return an error to the client """ oldpath = _to_local_path(self.map_path(oldpath)) newpath = _to_local_path(self.map_path(newpath)) os.link(oldpath, newpath) def fsync(self, file_obj): """Force file data to be written to disk :param file file_obj: The open file containing the data to flush to disk :raises: 
:exc:`SFTPError` to return an error to the client """ os.fsync(file_obj.fileno()) def exit(self): """Shut down this SFTP server""" pass class SFTPServerFile: """A wrapper around SFTPServer used to access files it manages""" def __init__(self, server): self._server = server self._file_obj = None @asyncio.coroutine def stat(self, path): """Get attributes of a file""" attrs = self._server.stat(path) if asyncio.iscoroutine(attrs): attrs = yield from attrs if isinstance(attrs, os.stat_result): attrs = SFTPAttrs.from_local(attrs) return attrs @asyncio.coroutine def setstat(self, path, attrs): """Set attributes of a file or directory""" result = self._server.setstat(path, attrs) if asyncio.iscoroutine(result): attrs = yield from result @asyncio.coroutine def _mode(self, path): """Return the file mode of a path, or 0 if it can't be accessed""" try: return (yield from self.stat(path)).permissions except OSError as exc: if exc.errno in (errno.ENOENT, errno.EACCES): return 0 else: raise except SFTPError as exc: if exc.code in (FX_NO_SUCH_FILE, FX_PERMISSION_DENIED): return 0 else: raise @asyncio.coroutine def exists(self, path): """Return if a path exists""" return (yield from self._mode(path)) != 0 @asyncio.coroutine def isdir(self, path): """Return if the path refers to a directory""" return stat.S_ISDIR((yield from self._mode(path))) @asyncio.coroutine def mkdir(self, path): """Create a directory""" result = self._server.mkdir(path, SFTPAttrs()) if asyncio.iscoroutine(result): yield from result @asyncio.coroutine def listdir(self, path): """List the contents of a directory""" files = self._server.listdir(path) if asyncio.iscoroutine(files): files = yield from files return files @asyncio.coroutine def open(self, path, mode='rb'): """Open a file""" pflags, _ = _mode_to_pflags(mode) file_obj = self._server.open(path, pflags, SFTPAttrs()) if asyncio.iscoroutine(file_obj): file_obj = yield from file_obj self._file_obj = file_obj return self @asyncio.coroutine def read(self, size, offset): """Read bytes from the file""" data = self._server.read(self._file_obj, offset, size) if asyncio.iscoroutine(data): data = yield from data return data @asyncio.coroutine def write(self, data, offset): """Write bytes to the file""" size = self._server.write(self._file_obj, offset, data) if asyncio.iscoroutine(size): size = yield from size return size @asyncio.coroutine def close(self): """Close a file managed by the associated SFTPServer""" result = self._server.close(self._file_obj) if asyncio.iscoroutine(result): yield from result @asyncio.coroutine def start_sftp_client(conn, loop, reader, writer, path_encoding, path_errors): """Start an SFTP client""" handler = SFTPClientHandler(loop, reader, writer) yield from handler.start() conn.create_task(handler.recv_packets()) return SFTPClient(loop, handler, path_encoding, path_errors) @asyncio.coroutine def run_sftp_server(sftp_server, reader, writer): """Return a handler for an SFTP server session""" handler = SFTPServerHandler(sftp_server, reader, writer) yield from handler.run() asyncssh-1.11.1/asyncssh/stream.py000066400000000000000000000470351320320510200171230ustar00rootroot00000000000000# Copyright (c) 2013-2017 by Ron Frederick . # All rights reserved. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH stream handlers""" import asyncio from .constants import EXTENDED_DATA_STDERR from .misc import BreakReceived, SignalReceived, TerminalSizeChanged from .misc import async_iterator, python35 from .session import SSHClientSession, SSHServerSession from .session import SSHTCPSession, SSHUNIXSession from .sftp import run_sftp_server from .scp import run_scp_server _NEWLINE = object() class SSHReader: """SSH read stream handler""" def __init__(self, session, chan, datatype=None): self._session = session self._chan = chan self._datatype = datatype if python35: @async_iterator def __aiter__(self): """Allow SSHReader to be an async iterator""" return self @asyncio.coroutine def __anext__(self): """Return one line at a time when used as an async iterator""" line = yield from self.readline() if line: return line else: raise StopAsyncIteration @property def channel(self): """The SSH channel associated with this stream""" return self._chan def get_extra_info(self, name, default=None): """Return additional information about this stream This method returns extra information about the channel associated with this stream. See :meth:`get_extra_info() ` on :class:`SSHClientChannel` for additional information. """ return self._chan.get_extra_info(name, default) @asyncio.coroutine def read(self, n=-1): """Read data from the stream This method is a coroutine which reads up to ``n`` bytes or characters from the stream. If ``n`` is not provided or set to ``-1``, it reads until EOF or a signal is received. If EOF is received and the receive buffer is empty, an empty bytes or str object is returned. If the next data in the stream is a signal, the signal is delivered as a raised exception. .. note:: Unlike traditional ``asyncio`` stream readers, the data will be delivered as either bytes or a str depending on whether an encoding was specified when the underlying channel was opened. """ return self._session.read(n, self._datatype, exact=False) @asyncio.coroutine def readline(self): """Read one line from the stream This method is a coroutine which reads one line, ending in ``'\\n'``. If EOF is received before ``'\\n'`` is found, the partial line is returned. If EOF is received and the receive buffer is empty, an empty bytes or str object is returned. If the next data in the stream is a signal, the signal is delivered as a raised exception. .. note:: In Python 3.5 and later, :class:`SSHReader` objects can also be used as async iterators, returning input data one line at a time. """ try: return (yield from self.readuntil(_NEWLINE)) except asyncio.IncompleteReadError as exc: return exc.partial @asyncio.coroutine def readuntil(self, separator): """Read data from the stream until ``separator`` is seen This method is a coroutine which reads from the stream until the requested separator is seen. If a match is found, the returned data will include the separator at the end. If EOF or a signal is received before a match occurs, an :exc:`IncompleteReadError ` is raised and its ``partial`` attribute will contain the data in the stream prior to the EOF or signal. If the next data in the stream is a signal, the signal is delivered as a raised exception. 
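
# Usage sketch (illustrative): on the client side, SSHReader objects are
# returned for stdout and stderr by open_session().  The remote command
# and host name are assumptions for the example.

import asyncio
import asyncssh


@asyncio.coroutine
def read_remote_output():
    conn = yield from asyncssh.connect('localhost')

    try:
        stdin, stdout, stderr = yield from conn.open_session('ls -1')

        while not stdout.at_eof():
            line = yield from stdout.readline()
            if line:
                print('remote:', line.rstrip())
    finally:
        conn.close()
        yield from conn.wait_closed()
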
""" return self._session.readuntil(separator, self._datatype) @asyncio.coroutine def readexactly(self, n): """Read an exact amount of data from the stream This method is a coroutine which reads exactly n bytes or characters from the stream. If EOF or a signal is received in the stream before ``n`` bytes are read, an :exc:`IncompleteReadError ` is raised and its ``partial`` attribute will contain the data before the EOF or signal. If the next data in the stream is a signal, the signal is delivered as a raised exception. """ return self._session.read(n, self._datatype, exact=True) def at_eof(self): """Return whether the stream is at EOF This method returns ``True`` when EOF has been received and all data in the stream has been read. """ return self._session.at_eof(self._datatype) def get_redirect_info(self): """Get information needed to redirect from this SSHReader""" return self._session, self._datatype class SSHWriter: """SSH write stream handler""" def __init__(self, session, chan, datatype=None): self._session = session self._chan = chan self._datatype = datatype @property def channel(self): """The SSH channel associated with this stream""" return self._chan def get_extra_info(self, name, default=None): """Return additional information about this stream This method returns extra information about the channel associated with this stream. See :meth:`get_extra_info() ` on :class:`SSHClientChannel` for additional information. """ return self._chan.get_extra_info(name, default) def can_write_eof(self): """Return whether the stream supports :meth:`write_eof`""" return self._chan.can_write_eof() def close(self): """Close the channel .. note:: After this is called, no data can be read or written from any of the streams associated with this channel. """ return self._chan.close() @asyncio.coroutine def drain(self): """Wait until the write buffer on the channel is flushed This method is a coroutine which blocks the caller if the stream is currently paused for writing, returning when enough data has been sent on the channel to allow writing to resume. This can be used to avoid buffering an excessive amount of data in the channel's send buffer. """ return (yield from self._session.drain(self._datatype)) def write(self, data): """Write data to the stream This method writes bytes or characters to the stream. .. note:: Unlike traditional ``asyncio`` stream writers, the data must be supplied as either bytes or a str depending on whether an encoding was specified when the underlying channel was opened. """ return self._chan.write(data, self._datatype) def writelines(self, list_of_data): """Write a collection of data to the stream""" return self._chan.writelines(list_of_data, self._datatype) def write_eof(self): """Write EOF on the channel This method sends an end-of-file indication on the channel, after which no more data can be written. .. note:: On an :class:`SSHServerChannel` where multiple output streams are created, writing EOF on one stream signals EOF for all of them, since it applies to the channel as a whole. 
""" return self._chan.write_eof() def get_redirect_info(self): """Get information needed to redirect to this SSHWriter""" return self._session, self._datatype class SSHStreamSession: """SSH stream session handler""" def __init__(self): self._chan = None self._conn = None self._encoding = None self._loop = None self._limit = None self._exception = None self._eof_received = False self._connection_lost = False self._recv_buf = {None: []} self._recv_buf_len = 0 self._read_waiters = {None: None} self._read_paused = False self._write_paused = False self._drain_waiters = {None: set()} @asyncio.coroutine def _block_read(self, datatype): """Wait for more data to arrive on the stream""" if self._read_waiters[datatype]: raise RuntimeError('read called while another coroutine is ' 'already waiting to read') try: waiter = asyncio.Future(loop=self._loop) self._read_waiters[datatype] = waiter yield from waiter finally: self._read_waiters[datatype] = None def _unblock_read(self, datatype): """Signal that more data has arrived on the stream""" waiter = self._read_waiters[datatype] if waiter and not waiter.done(): waiter.set_result(None) def _should_block_drain(self, datatype): """Return whether output is still being written to the channel""" # pylint: disable=unused-argument return self._write_paused and not self._connection_lost def _unblock_drain(self, datatype): """Signal that more data can be written on the stream""" if not self._should_block_drain(datatype): for waiter in self._drain_waiters[datatype]: if not waiter.done(): # pragma: no branch waiter.set_result(None) def _should_pause_reading(self): """Return whether to pause reading from the channel""" return self._limit and self._recv_buf_len >= self._limit def _maybe_pause_reading(self): """Pause reading if necessary""" if not self._read_paused and self._should_pause_reading(): self._read_paused = True self._chan.pause_reading() return True else: return False def _maybe_resume_reading(self): """Resume reading if necessary""" if self._read_paused and not self._should_pause_reading(): self._read_paused = False self._chan.resume_reading() return True else: return False def connection_made(self, chan): """Handle a newly opened channel""" self._chan = chan self._conn = chan.get_connection() self._encoding = chan.get_encoding() self._loop = chan.get_loop() self._limit = self._chan.get_recv_window() for datatype in chan.get_read_datatypes(): self._recv_buf[datatype] = [] self._read_waiters[datatype] = None for datatype in chan.get_write_datatypes(): self._drain_waiters[datatype] = set() def connection_lost(self, exc): """Handle an incoming channel close""" self._connection_lost = True self._exception = exc if not self._eof_received: if exc: for datatype in self._read_waiters: self._recv_buf[datatype].append(exc) self.eof_received() for datatype in self._drain_waiters: self._unblock_drain(datatype) def data_received(self, data, datatype): """Handle incoming data on the channel""" self._recv_buf[datatype].append(data) self._recv_buf_len += len(data) self._unblock_read(datatype) self._maybe_pause_reading() def eof_received(self): """Handle an incoming end of file on the channel""" self._eof_received = True for datatype in self._read_waiters: self._unblock_read(datatype) return True def at_eof(self, datatype): """Return whether end of file has been received on the channel""" return self._eof_received and not self._recv_buf[datatype] def pause_writing(self): """Handle a request to pause writing on the channel""" self._write_paused = True def 
resume_writing(self): """Handle a request to resume writing on the channel""" self._write_paused = False for datatype in self._drain_waiters: self._unblock_drain(datatype) @asyncio.coroutine def read(self, n, datatype, exact): """Read data from the channel""" recv_buf = self._recv_buf[datatype] buf = '' if self._encoding else b'' data = [] while True: while recv_buf and n != 0: if isinstance(recv_buf[0], Exception): if data: break else: raise recv_buf.pop(0) l = len(recv_buf[0]) if n > 0 and l > n: data.append(recv_buf[0][:n]) recv_buf[0] = recv_buf[0][n:] self._recv_buf_len -= n n = 0 break data.append(recv_buf.pop(0)) self._recv_buf_len -= l n -= l if self._maybe_resume_reading(): continue if n == 0 or (n > 0 and data and not exact) or self._eof_received: break yield from self._block_read(datatype) buf = buf.join(data) if n > 0 and exact: raise asyncio.IncompleteReadError(buf, len(buf) + n) return buf @asyncio.coroutine def readuntil(self, separator, datatype): """Read data from the channel until a separator is seen""" if separator is _NEWLINE: separator = '\n' if self._encoding else b'\n' elif not separator: raise ValueError('Separator cannot be empty') seplen = len(separator) recv_buf = self._recv_buf[datatype] buf = '' if self._encoding else b'' buflen = 0 while True: while recv_buf: if isinstance(recv_buf[0], Exception): if buf: raise asyncio.IncompleteReadError(buf, None) else: raise recv_buf.pop(0) buf += recv_buf[0] start = max(buflen + 1 - seplen, 0) idx = buf.find(separator, start) if idx >= 0: idx += seplen recv_buf[0] = buf[idx:] buf = buf[:idx] self._recv_buf_len -= idx if not recv_buf[0]: recv_buf.pop(0) self._maybe_resume_reading() return buf l = len(recv_buf[0]) buflen += l self._recv_buf_len -= l recv_buf.pop(0) if self._maybe_resume_reading(): continue if self._eof_received: raise asyncio.IncompleteReadError(buf, None) yield from self._block_read(datatype) @asyncio.coroutine def drain(self, datatype): """Wait for data written to the channel to drain""" while self._should_block_drain(datatype): try: waiter = asyncio.Future(loop=self._loop) self._drain_waiters[datatype].add(waiter) yield from waiter finally: self._drain_waiters[datatype].remove(waiter) if self._connection_lost: exc = self._exception if not exc and self._write_paused: exc = BrokenPipeError() if exc: raise exc # pylint: disable=raising-bad-type class SSHClientStreamSession(SSHStreamSession, SSHClientSession): """SSH client stream session handler""" class SSHServerStreamSession(SSHStreamSession, SSHServerSession): """SSH server stream session handler""" def __init__(self, session_factory, sftp_factory, allow_scp): super().__init__() self._session_factory = session_factory self._sftp_factory = sftp_factory self._allow_scp = allow_scp and bool(sftp_factory) def shell_requested(self): """Return whether a shell can be requested""" return bool(self._session_factory) def exec_requested(self, command): """Return whether execution of a command can be requested""" return ((self._allow_scp and command.startswith('scp ')) or bool(self._session_factory)) def subsystem_requested(self, subsystem): """Return whether starting a subsystem can be requested""" if subsystem == 'sftp': return bool(self._sftp_factory) else: return bool(self._session_factory) def session_started(self): """Start a session for this newly opened server channel""" command = self._chan.get_command() stdin = SSHReader(self, self._chan) stdout = SSHWriter(self, self._chan) stderr = SSHWriter(self, self._chan, EXTENDED_DATA_STDERR) if 
self._chan.get_subsystem() == 'sftp': self._chan.set_encoding(None) self._encoding = None handler = run_sftp_server(self._sftp_factory(self._conn), stdin, stdout) elif self._allow_scp and command and command.startswith('scp '): self._chan.set_encoding(None) self._encoding = None handler = run_scp_server(self._sftp_factory(self._conn), command, stdin, stdout, stderr) else: handler = self._session_factory(stdin, stdout, stderr) if asyncio.iscoroutine(handler): self._conn.create_task(handler) def break_received(self, msec): """Handle an incoming break on the channel""" self._recv_buf[None].append(BreakReceived(msec)) self._unblock_read(None) return True def signal_received(self, signal): """Handle an incoming signal on the channel""" self._recv_buf[None].append(SignalReceived(signal)) self._unblock_read(None) def terminal_size_changed(self, *args): """Handle an incoming terminal size change on the channel""" self._recv_buf[None].append(TerminalSizeChanged(*args)) self._unblock_read(None) class SSHSocketStreamSession(SSHStreamSession): """Socket stream session handler""" def __init__(self, handler_factory=None): super().__init__() self._handler_factory = handler_factory def session_started(self): """Start a session for this newly opened socket channel""" if self._handler_factory: handler = self._handler_factory(SSHReader(self, self._chan), SSHWriter(self, self._chan)) if asyncio.iscoroutine(handler): self._conn.create_task(handler) class SSHTCPStreamSession(SSHSocketStreamSession, SSHTCPSession): """TCP stream session handler""" class SSHUNIXStreamSession(SSHSocketStreamSession, SSHUNIXSession): """UNIX stream session handler""" asyncssh-1.11.1/asyncssh/version.py000066400000000000000000000011021320320510200172760ustar00rootroot00000000000000# Copyright (c) 2013-2017 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """AsyncSSH version information""" __author__ = 'Ron Frederick' __author_email__ = 'ronf@timeheart.net' __url__ = 'http://asyncssh.timeheart.net' __version__ = '1.11.1' asyncssh-1.11.1/asyncssh/x11.py000066400000000000000000000357031320320510200162400ustar00rootroot00000000000000# Copyright (c) 2016-2017 by Ron Frederick . # All rights reserved. 
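
# Usage sketch (illustrative): the server-side stream sessions above hand a
# session_factory three streams (stdin, stdout, stderr).  A simple line-echo
# handler might look like this; how it gets registered (typically through a
# session_factory argument when the SSH server is created) depends on the
# embedding application.

import asyncio
import asyncssh


@asyncio.coroutine
def echo_handler(stdin, stdout, stderr):
    try:
        while not stdin.at_eof():
            line = yield from stdin.readline()
            if line:
                stdout.write(line)
    except asyncssh.BreakReceived:
        stderr.write('break received\r\n')

    stdout.close()
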
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """X11 forwarding support""" import asyncio import os import socket import time from collections import namedtuple from .constants import OPEN_CONNECT_FAILED from .forward import SSHForwarder from .listener import create_tcp_forward_listener from .logging import logger from .misc import ChannelOpenError # pylint: disable=bad-whitespace # Xauth address families XAUTH_FAMILY_IPV4 = 0 XAUTH_FAMILY_DECNET = 1 XAUTH_FAMILY_IPV6 = 6 XAUTH_FAMILY_HOSTNAME = 256 XAUTH_FAMILY_WILD = 65535 # Xauth protocol values XAUTH_PROTO_COOKIE = b'MIT-MAGIC-COOKIE-1' XAUTH_COOKIE_LEN = 16 # Xauth lock information XAUTH_LOCK_SUFFIX = '-c' XAUTH_LOCK_TRIES = 5 XAUTH_LOCK_DELAY = 0.2 XAUTH_LOCK_DEAD = 5 # X11 display and port numbers X11_BASE_PORT = 6000 X11_DISPLAY_START = 10 X11_MAX_DISPLAYS = 64 # Host to listen on when doing X11 forwarding X11_LISTEN_HOST = 'localhost' # pylint: enable=bad-whitespace def _parse_display(display): """Parse an X11 display value""" try: host, dpynum = display.rsplit(':', 1) if host.startswith('[') and host.endswith(']'): host = host[1:-1] idx = dpynum.find('.') if idx >= 0: screen = int(dpynum[idx+1:]) dpynum = dpynum[:idx] else: screen = 0 except (ValueError, UnicodeEncodeError): raise ValueError('Invalid X11 display') from None return host, dpynum, screen @asyncio.coroutine def _lookup_host(loop, host, family): """Look up IPv4 or IPv6 addresses of a host name""" try: addrinfo = yield from loop.getaddrinfo(host, 0, family=family, type=socket.SOCK_STREAM) except socket.gaierror: return [] return [ai[4][0] for ai in addrinfo] class SSHXAuthorityEntry(namedtuple('SSHXAuthorityEntry', 'family addr dpynum proto data')): """An entry in an Xauthority file""" def __bytes__(self): """Construct an Xauthority entry""" def _uint16(value): """Construct a big-endian 16-bit unsigned integer""" return value.to_bytes(2, 'big') def _string(data): """Construct a binary string with a 16-bit length""" return _uint16(len(data)) + data return b''.join((_uint16(self.family), _string(self.addr), _string(self.dpynum), _string(self.proto), _string(self.data))) class SSHX11ClientForwarder(SSHForwarder): """X11 forwarding connection handler""" def __init__(self, listener, peer): super().__init__(peer) self._listener = listener self._inpbuf = b'' self._bytes_needed = 12 self._recv_handler = self._recv_prefix self._endian = b'' self._prefix = b'' self._auth_proto_len = 0 self._auth_data_len = 0 self._auth_proto = b'' self._auth_proto_pad = b'' self._auth_data = b'' self._auth_data_pad = b'' def _encode_uint16(self, value): """Encode a 16-bit unsigned integer""" if self._endian == b'B': return bytes((value >> 8, value & 255)) else: return bytes((value & 255, value >> 8)) def _decode_uint16(self, value): """Decode a 16-bit unsigned integer""" if self._endian == b'B': return (value[0] << 8) + value[1] else: return (value[1] << 8) + value[0] @staticmethod def _padded_len(length): """Return length rounded up to the next multiple of 4 bytes""" return ((length + 3) // 4) * 4 @staticmethod def _pad(data): """Pad a string to a multiple of 4 bytes""" length = len(data) % 4 return data + ((4 - length) * b'\00' if length else b'') def _recv_prefix(self, data): """Parse X11 client prefix""" self._endian = data[:1] 
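
# Sketch (illustrative): the 12-byte X11 connection setup prefix handled by
# _recv_prefix() above carries a byte-order marker followed by the lengths
# of the authorization protocol name and data.  A standalone parser using
# the same field offsets and byte-order rule:

def parse_x11_setup_prefix(prefix):
    """Return (auth proto length, auth data length) from a setup prefix"""

    if len(prefix) < 12:
        raise ValueError('X11 setup prefix must be at least 12 bytes')

    def _uint16(data):
        """Decode a 16-bit field using the prefix's byte-order marker"""
        return int.from_bytes(data, 'big' if prefix[:1] == b'B' else 'little')

    return _uint16(prefix[6:8]), _uint16(prefix[8:10])
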
self._prefix = data self._auth_proto_len = self._decode_uint16(data[6:8]) self._auth_data_len = self._decode_uint16(data[8:10]) self._recv_handler = self._recv_auth_proto self._bytes_needed = self._padded_len(self._auth_proto_len) def _recv_auth_proto(self, data): """Extract X11 auth protocol""" self._auth_proto = data[:self._auth_proto_len] self._auth_proto_pad = data[self._auth_proto_len:] self._recv_handler = self._recv_auth_data self._bytes_needed = self._padded_len(self._auth_data_len) def _recv_auth_data(self, data): """Extract X11 auth data""" self._auth_data = data[:self._auth_data_len] self._auth_data_pad = data[self._auth_data_len:] try: self._auth_data = self._listener.validate_auth(self._auth_data) except KeyError: reason = b'Invalid authentication key\n' response = b''.join((bytes((0, len(reason))), self._encode_uint16(11), self._encode_uint16(0), self._encode_uint16((len(reason) + 3) // 4), self._pad(reason))) try: self.write(response) self.write_eof() except OSError: # pragma: no cover pass self._inpbuf = b'' else: self._inpbuf = (self._prefix + self._auth_proto + self._auth_proto_pad + self._auth_data + self._auth_data_pad) self._recv_handler = None self._bytes_needed = 0 def data_received(self, data, datatype=None): """Handle incoming data from the X11 client""" if self._recv_handler: self._inpbuf += data while self._recv_handler: if len(self._inpbuf) >= self._bytes_needed: data = self._inpbuf[:self._bytes_needed] self._inpbuf = self._inpbuf[self._bytes_needed:] self._recv_handler(data) else: return data = self._inpbuf self._inpbuf = b'' if data: super().data_received(data, datatype) class SSHX11ClientListener: """Client listener used to accept forwarded X11 connections""" def __init__(self, loop, host, dpynum, auth_proto, auth_data): self._host = host self._dpynum = dpynum self._auth_proto = auth_proto self._local_auth = auth_data if host.startswith('/'): self._connect_coro = loop.create_unix_connection self._connect_args = (host + ':' + dpynum,) elif host in ('', 'unix'): self._connect_coro = loop.create_unix_connection self._connect_args = ('/tmp/.X11-unix/X' + dpynum,) else: self._connect_coro = loop.create_connection self._connect_args = (host, X11_BASE_PORT + int(dpynum)) self._remote_auth = {} self._channel = {} def attach(self, display, chan, single_connection): """Attach a channel to this listener""" host, dpynum, screen = _parse_display(display) if self._host != host or self._dpynum != dpynum: raise ValueError('Already forwarding to another X11 display') remote_auth = os.urandom(len(self._local_auth)) self._remote_auth[chan] = remote_auth self._channel[remote_auth] = chan, single_connection return self._auth_proto, remote_auth, screen def detach(self, chan): """Detach a channel from this listener""" try: remote_auth = self._remote_auth.pop(chan) del self._channel[remote_auth] except KeyError: pass return self._remote_auth == {} @asyncio.coroutine def forward_connection(self): """Forward an incoming connection to the local X server""" try: _, peer = yield from self._connect_coro(SSHForwarder, *self._connect_args) except OSError as exc: raise ChannelOpenError(OPEN_CONNECT_FAILED, str(exc)) from None return SSHX11ClientForwarder(self, peer) def validate_auth(self, remote_auth): """Validate client auth and enforce single connection flag""" chan, single_connection = self._channel[remote_auth] if single_connection: del self._channel[remote_auth] del self._remote_auth[chan] return self._local_auth class SSHX11ServerListener: """Server listener used to forward X11 
connections""" def __init__(self, tcp_listener, display): self._tcp_listener = tcp_listener self._display = display self._channels = set() def attach(self, chan, screen): """Attach a channel to this listener and return its display""" self._channels.add(chan) return '%s.%s' % (self._display, screen) def detach(self, chan): """Detach a channel from this listener""" try: self._channels.remove(chan) except KeyError: pass if not self._channels: self._tcp_listener.close() self._tcp_listener = None return True else: return False def get_xauth_path(auth_path): """Compute the path to the Xauthority file""" if not auth_path: auth_path = os.environ.get('XAUTHORITY') if not auth_path: auth_path = os.path.join(os.path.expanduser('~'), '.Xauthority') return auth_path def walk_xauth(auth_path): """Walk the entries in an Xauthority file""" def _read_bytes(n): """Read exactly n bytes""" data = auth_file.read(n) if len(data) != n: raise EOFError return data def _read_uint16(): """Read a 16-bit unsigned integer""" return int.from_bytes(_read_bytes(2), 'big') def _read_string(): """Read a string""" return _read_bytes(_read_uint16()) try: with open(auth_path, 'rb') as auth_file: while True: try: family = _read_uint16() except EOFError: break try: yield SSHXAuthorityEntry(family, _read_string(), _read_string(), _read_string(), _read_string()) except EOFError: raise ValueError('Incomplete Xauthority entry') from None except OSError: pass @asyncio.coroutine def lookup_xauth(loop, auth_path, host, dpynum): """Look up Xauthority data for the specified display""" auth_path = get_xauth_path(auth_path) if host.startswith('/') or host in ('', 'unix', 'localhost'): host = socket.gethostname() dpynum = dpynum.encode('ascii') ipv4_addrs = [] ipv6_addrs = [] for entry in walk_xauth(auth_path): if entry.dpynum and entry.dpynum != dpynum: continue if entry.family == XAUTH_FAMILY_IPV4: if not ipv4_addrs: ipv4_addrs = yield from _lookup_host(loop, host, socket.AF_INET) addr = socket.inet_ntop(socket.AF_INET, entry.addr) match = addr in ipv4_addrs elif entry.family == XAUTH_FAMILY_IPV6: if not ipv6_addrs: ipv6_addrs = yield from _lookup_host(loop, host, socket.AF_INET6) addr = socket.inet_ntop(socket.AF_INET6, entry.addr) match = addr in ipv6_addrs elif entry.family == XAUTH_FAMILY_HOSTNAME: match = entry.addr == host.encode('idna') elif entry.family == XAUTH_FAMILY_WILD: match = True else: match = False if match: return entry.proto, entry.data logger.warning('No xauth entry found for display: using random auth') return XAUTH_PROTO_COOKIE, os.urandom(XAUTH_COOKIE_LEN) @asyncio.coroutine def update_xauth(loop, auth_path, host, dpynum, auth_proto, auth_data): """Update Xauthority data for the specified display""" if host.startswith('/') or host in ('', 'unix', 'localhost'): host = socket.gethostname() host = host.encode('idna') dpynum = str(dpynum).encode('ascii') auth_path = get_xauth_path(auth_path) new_auth_path = auth_path + XAUTH_LOCK_SUFFIX new_file = None try: if time.time() - os.stat(new_auth_path).st_ctime > XAUTH_LOCK_DEAD: os.unlink(new_auth_path) except FileNotFoundError: pass for _ in range(XAUTH_LOCK_TRIES): try: new_file = open(new_auth_path, 'xb') except FileExistsError: yield from asyncio.sleep(XAUTH_LOCK_DELAY, loop=loop) else: break if not new_file: raise ValueError('Unable to acquire Xauthority lock') new_entry = SSHXAuthorityEntry(XAUTH_FAMILY_HOSTNAME, host, dpynum, auth_proto, auth_data) new_file.write(bytes(new_entry)) for entry in walk_xauth(auth_path): if (entry.family != new_entry.family or entry.addr 
!= new_entry.addr or entry.dpynum != new_entry.dpynum): new_file.write(bytes(entry)) new_file.close() os.replace(new_auth_path, auth_path) @asyncio.coroutine def create_x11_client_listener(loop, display, auth_path): """Create a listener to accept X11 connections forwarded over SSH""" host, dpynum, _ = _parse_display(display) auth_proto, auth_data = yield from lookup_xauth(loop, auth_path, host, dpynum) return SSHX11ClientListener(loop, host, dpynum, auth_proto, auth_data) @asyncio.coroutine def create_x11_server_listener(conn, loop, auth_path, auth_proto, auth_data): """Create a listener to forward X11 connections over SSH""" for dpynum in range(X11_DISPLAY_START, X11_MAX_DISPLAYS): try: tcp_listener = yield from create_tcp_forward_listener( conn, loop, conn.create_x11_connection, X11_LISTEN_HOST, X11_BASE_PORT + dpynum) except OSError: continue display = '%s:%d' % (X11_LISTEN_HOST, dpynum) try: yield from update_xauth(loop, auth_path, X11_LISTEN_HOST, dpynum, auth_proto, auth_data) except ValueError: tcp_listener.close() break return SSHX11ServerListener(tcp_listener, display) return None asyncssh-1.11.1/docs/000077500000000000000000000000001320320510200143425ustar00rootroot00000000000000asyncssh-1.11.1/docs/_templates/000077500000000000000000000000001320320510200164775ustar00rootroot00000000000000asyncssh-1.11.1/docs/_templates/sidebarbottom.html000066400000000000000000000007101320320510200222210ustar00rootroot00000000000000

Change Log

Contributing

API Documentation

Source on PyPI

Source on GitHub

Issue Tracker

Search

asyncssh-1.11.1/docs/_templates/sidebartop.html000066400000000000000000000001231320320510200215150ustar00rootroot00000000000000 AsyncSSH
Version {{version}}

asyncssh-1.11.1/docs/api.rst000066400000000000000000001546601320320510200156610ustar00rootroot00000000000000.. module:: asyncssh .. _API: API Documentation ***************** Overview ======== The AsyncSSH API is modeled after the new Python ``asyncio`` framework, with a :func:`create_connection` coroutine to create an SSH client and a :func:`create_server` coroutine to create an SSH server. Like the ``asyncio`` framework, these calls take a parameter of a factory which creates protocol objects to manage the connections once they are open. For AsyncSSH, :func:`create_connection` should be passed a ``client_factory`` which returns objects derived from :class:`SSHClient` and :func:`create_server` should be passed a ``server_factory`` which returns objects derived from :class:`SSHServer`. In addition, each connection will have an associated :class:`SSHClientConnection` or :class:`SSHServerConnection` object passed to the protocol objects which can be used to perform actions on the connection. For client connections, authentication can be performed by passing in a username and password or SSH keys as arguments to :func:`create_connection` or by implementing handler methods on the :class:`SSHClient` object which return credentials when the server requests them. If no credentials are provided, AsyncSSH automatically attempts to send the username of the local user and the keys found in their :file:`.ssh` subdirectory. A list of expected server host keys can also be specified, with AsyncSSH defaulting to looking for matching lines in the user's :file:`.ssh/known_hosts` file. For server connections, handlers can be implemented on the :class:`SSHServer` object to return which authentication methods are supported and to validate credentials provided by clients. Once an SSH client connection is established and authentication is successful, multiple simultaneous channels can be opened on it. This is accomplished calling methods such as :meth:`create_session() `, :meth:`create_connection() `, and :meth:`create_unix_connection() ` on the :class:`SSHClientConnection` object. The client can also set up listeners on remote TCP ports and UNIX domain sockets by calling :meth:`create_server() ` and :meth:`create_unix_server() `. All of these methods take ``session_factory`` arguments that return :class:`SSHClientSession`, :class:`SSHTCPSession`, or :class:`SSHUNIXSession` objects used to manage the channels once they are open. Alternately, channels can be opened using :meth:`open_session() `, :meth:`open_connection() `, or :meth:`open_unix_connection() `, which return :class:`SSHReader` and :class:`SSHWriter` objects that can be used to perform I/O on the channel. The methods :meth:`start_server() ` and :meth:`start_unix_server() ` can be used to set up listeners on remote TCP ports or UNIX domain sockets and get back these :class:`SSHReader` and :class:`SSHWriter` objects in a callback when new connections are opened. SSH client sessions can also be opened by calling :meth:`create_process() `. This returns a :class:`SSHClientProcess` object which has members ``stdin``, ``stdout``, and ``stderr`` which are :class:`SSHReader` and :class:`SSHWriter` objects. This API also makes it very easy to redirect input and output from the remote process to local files, pipes, sockets, or other :class:`SSHReader` and :class:`SSHWriter` objects. In cases where you just want to run a remote process to completion and get back an object containing captured output and exit status, the :meth:`run() ` method can be used. 
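As a minimal sketch of this pattern (the host name ``localhost`` and the reliance on default client keys and known_hosts lookup are placeholder assumptions, not requirements of the API), a client that runs a single command might look like::

    import asyncio
    import asyncssh

    @asyncio.coroutine
    def run_client():
        # 'localhost' is a placeholder host used purely for illustration
        conn = yield from asyncssh.connect('localhost')

        try:
            # check=True raises ProcessError on a non-zero exit status
            result = yield from conn.run('echo "Hello"', check=True)
            print(result.stdout, end='')
        finally:
            conn.close()
            yield from conn.wait_closed()

    asyncio.get_event_loop().run_until_complete(run_client())
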
It returns an :class:`SSHCompletedProcess` with the results of the run, or can be set up to raise :class:`ProcessError` if the process exits with a non-zero exit status. The client can also set up TCP port forwarding by calling :meth:`forward_local_port() ` or :meth:`forward_remote_port() ` and UNIX domain socket forwarding by calling :meth:`forward_local_path() ` or :meth:`forward_remote_path() `. In these cases, data transfer on the channels is managed automatically by AsyncSSH whenever new connections are opened, so custom session objects are not required. When an SSH server receives a new connection and authentication is successful, handlers such as :meth:`session_requested() `, :meth:`connection_requested() `, :meth:`unix_connection_requested() `, :meth:`server_requested() `, and :meth:`unix_server_requested() ` on the associated :class:`SSHServer` object will be called when clients attempt to open channels or set up listeners. These methods return coroutines which can set up the requested sessions or connections, returning :class:`SSHServerSession` or :class:`SSHTCPSession` objects or handler functions that accept :class:`SSHReader` and :class:`SSHWriter` objects as arguments which manage the channels once they are open. To better support interactive server applications, AsyncSSH defaults to providing echoing of input and basic line editing capabilities when an inbound SSH session requests a pseudo-terminal. This behavior can be disabled by setting the ``line_editor`` argument to ``False`` when starting up an SSH server. When this feature is enabled, server sessions can enable or disable line mode using the :meth:`set_line_mode() ` method of :class:`SSHLineEditorChannel`. They can also enable or disable input echoing using the :meth:`set_echo() ` method. Each session object also has an associated :class:`SSHClientChannel`, :class:`SSHServerChannel`, or :class:`SSHTCPChannel` object passed to it which can be used to perform actions on the channel. These channel objects provide a superset of the functionality found in ``asyncio`` transport objects. In addition to the above functions and classes, helper functions for importing public and private keys can be found below under :ref:`PublicKeySupport`, exceptions can be found under :ref:`Exceptions`, supported algorithms can be found under :ref:`SupportedAlgorithms`, and some useful constants can be found under :ref:`Constants`. Main Functions ============== create_connection ----------------- .. autofunction:: create_connection create_server ------------- .. autofunction:: create_server connect ------- .. autofunction:: connect listen ------ .. autofunction:: listen scp --- .. autofunction:: scp Main Classes ============ SSHClient --------- .. autoclass:: SSHClient ================================== = General connection handlers ================================== = .. automethod:: connection_made .. automethod:: connection_lost .. automethod:: debug_msg_received ================================== = ==================================== = General authentication handlers ==================================== = .. automethod:: auth_banner_received .. automethod:: auth_completed ==================================== = ========================================= = Public key authentication handlers ========================================= = .. automethod:: public_key_auth_requested ========================================= = ========================================= = Password authentication handlers ========================================= = .. 
automethod:: password_auth_requested .. automethod:: password_change_requested .. automethod:: password_changed .. automethod:: password_change_failed ========================================= = ============================================ = Keyboard-interactive authentication handlers ============================================ = .. automethod:: kbdint_auth_requested .. automethod:: kbdint_challenge_received ============================================ = SSHServer --------- .. autoclass:: SSHServer ================================== = General connection handlers ================================== = .. automethod:: connection_made .. automethod:: connection_lost .. automethod:: debug_msg_received ================================== = =============================== = General authentication handlers =============================== = .. automethod:: begin_auth =============================== = ====================================== = GSSAPI authentication handlers ====================================== = .. automethod:: validate_gss_principal ====================================== = ========================================= = Public key authentication handlers ========================================= = .. automethod:: public_key_auth_supported .. automethod:: validate_public_key .. automethod:: validate_ca_key ========================================= = ======================================= = Password authentication handlers ======================================= = .. automethod:: password_auth_supported .. automethod:: validate_password .. automethod:: change_password ======================================= = ============================================ = Keyboard-interactive authentication handlers ============================================ = .. automethod:: kbdint_auth_supported .. automethod:: get_kbdint_challenge .. automethod:: validate_kbdint_response ============================================ = ========================================= = Channel session open handlers ========================================= = .. automethod:: session_requested .. automethod:: connection_requested .. automethod:: unix_connection_requested .. automethod:: server_requested .. automethod:: unix_server_requested ========================================= = Connection Classes ================== SSHClientConnection ------------------- .. autoclass:: SSHClientConnection() ============================== = General connection methods ============================== = .. automethod:: get_extra_info .. automethod:: send_debug ============================== = ================================================================================================================================= = Client session open methods ================================================================================================================================= = .. automethod:: create_session .. automethod:: open_session .. automethod:: create_process(*args, bufsize=io.DEFAULT_BUFFER_SIZE, input=None, stdin=PIPE, stdout=PIPE, stderr=PIPE, **kwargs) .. automethod:: run(*args, check=False, **kwargs) .. automethod:: start_sftp_client .. automethod:: create_ssh_connection .. automethod:: connect_ssh ================================================================================================================================= = ====================================== = Client connection open methods ====================================== = .. automethod:: create_connection .. automethod:: open_connection .. 
automethod:: create_server .. automethod:: start_server .. automethod:: create_unix_connection .. automethod:: open_unix_connection .. automethod:: create_unix_server .. automethod:: start_unix_server ====================================== = =================================== = Client forwarding methods =================================== = .. automethod:: forward_connection .. automethod:: forward_local_port .. automethod:: forward_local_path .. automethod:: forward_remote_port .. automethod:: forward_remote_path =================================== = =========================== = Connection close methods =========================== = .. automethod:: abort .. automethod:: close .. automethod:: disconnect .. automethod:: wait_closed =========================== = SSHServerConnection ------------------- .. autoclass:: SSHServerConnection() ============================== = General connection methods ============================== = .. automethod:: get_extra_info .. automethod:: send_debug ============================== = ============================================ = Server authentication methods ============================================ = .. automethod:: send_auth_banner .. automethod:: set_authorized_keys .. automethod:: get_key_option .. automethod:: check_key_permission .. automethod:: get_certificate_option .. automethod:: check_certificate_permission ============================================ = ====================================== = Server connection open methods ====================================== = .. automethod:: create_connection .. automethod:: open_connection .. automethod:: create_unix_connection .. automethod:: open_unix_connection ====================================== = ======================================= = Server forwarding methods ======================================= = .. automethod:: forward_connection .. automethod:: forward_unix_connection ======================================= = ===================================== = Server channel creation methods ===================================== = .. automethod:: create_server_channel .. automethod:: create_tcp_channel .. automethod:: create_unix_channel ===================================== = =========================== = Connection close methods =========================== = .. automethod:: abort .. automethod:: close .. automethod:: disconnect .. automethod:: wait_closed =========================== = Process Classes =============== SSHClientProcess ---------------- .. autoclass:: SSHClientProcess ============================== = Client process attributes ============================== = .. autoattribute:: channel .. autoattribute:: env .. autoattribute:: command .. autoattribute:: subsystem .. autoattribute:: stdin .. autoattribute:: stdout .. autoattribute:: stderr .. autoattribute:: exit_status .. autoattribute:: exit_signal ============================== = ==================================== = Other client process methods ==================================== = .. automethod:: redirect .. automethod:: communicate .. automethod:: wait .. automethod:: change_terminal_size .. automethod:: send_break .. automethod:: send_signal ==================================== = ============================ = Client process close methods ============================ = .. automethod:: terminate .. automethod:: kill .. automethod:: close .. automethod:: wait_closed ============================ = SSHServerProcess ---------------- .. 
autoclass:: SSHServerProcess ============================== = Server process attributes ============================== = .. autoattribute:: channel .. autoattribute:: env .. autoattribute:: command .. autoattribute:: subsystem .. autoattribute:: stdin .. autoattribute:: stdout .. autoattribute:: stderr ============================== = ================================= = Other server process methods ================================= = .. automethod:: get_terminal_type .. automethod:: get_terminal_size .. automethod:: get_terminal_mode .. automethod:: redirect ================================= = ================================ = Server process close methods ================================ = .. automethod:: exit .. automethod:: exit_with_signal .. automethod:: close .. automethod:: wait_closed ================================ = SSHCompletedProcess ------------------- .. autoclass:: SSHCompletedProcess() Session Classes =============== SSHClientSession ---------------- .. autoclass:: SSHClientSession =============================== = General session handlers =============================== = .. automethod:: connection_made .. automethod:: connection_lost .. automethod:: session_started =============================== = ============================= = General session read handlers ============================= = .. automethod:: data_received .. automethod:: eof_received ============================= = ============================== = General session write handlers ============================== = .. automethod:: pause_writing .. automethod:: resume_writing ============================== = ==================================== = Other client session handlers ==================================== = .. automethod:: xon_xoff_requested .. automethod:: exit_status_received .. automethod:: exit_signal_received ==================================== = SSHServerSession ---------------- .. autoclass:: SSHServerSession =============================== = General session handlers =============================== = .. automethod:: connection_made .. automethod:: connection_lost .. automethod:: session_started =============================== = =================================== = Server session open handlers =================================== = .. automethod:: pty_requested .. automethod:: shell_requested .. automethod:: exec_requested .. automethod:: subsystem_requested =================================== = ============================= = General session read handlers ============================= = .. automethod:: data_received .. automethod:: eof_received ============================= = ============================== = General session write handlers ============================== = .. automethod:: pause_writing .. automethod:: resume_writing ============================== = ===================================== = Other server session handlers ===================================== = .. automethod:: break_received .. automethod:: signal_received .. automethod:: terminal_size_changed ===================================== = SSHTCPSession ------------- .. autoclass:: SSHTCPSession =============================== = General session handlers =============================== = .. automethod:: connection_made .. automethod:: connection_lost .. automethod:: session_started =============================== = ============================= = General session read handlers ============================= = .. automethod:: data_received .. 
automethod:: eof_received ============================= = ============================== = General session write handlers ============================== = .. automethod:: pause_writing .. automethod:: resume_writing ============================== = SSHUNIXSession -------------- .. autoclass:: SSHUNIXSession =============================== = General session handlers =============================== = .. automethod:: connection_made .. automethod:: connection_lost .. automethod:: session_started =============================== = ============================= = General session read handlers ============================= = .. automethod:: data_received .. automethod:: eof_received ============================= = ============================== = General session write handlers ============================== = .. automethod:: pause_writing .. automethod:: resume_writing ============================== = Channel Classes =============== SSHClientChannel ---------------- .. autoclass:: SSHClientChannel() =============================== = General channel info methods =============================== = .. automethod:: get_extra_info .. automethod:: get_environment .. automethod:: get_command .. automethod:: get_subsystem =============================== = ============================== = Client channel read methods ============================== = .. automethod:: pause_reading .. automethod:: resume_reading ============================== = ======================================= = Client channel write methods ======================================= = .. automethod:: can_write_eof .. automethod:: get_write_buffer_size .. automethod:: set_write_buffer_limits .. automethod:: write .. automethod:: writelines .. automethod:: write_eof ======================================= = ===================================== = Other client channel methods ===================================== = .. automethod:: get_exit_status .. automethod:: get_exit_signal .. automethod:: change_terminal_size .. automethod:: send_break .. automethod:: send_signal .. automethod:: kill .. automethod:: terminate ===================================== = ============================= = General channel close methods ============================= = .. automethod:: abort .. automethod:: close .. automethod:: wait_closed ============================= = SSHServerChannel ---------------- .. autoclass:: SSHServerChannel() =============================== = General channel info methods =============================== = .. automethod:: get_extra_info .. automethod:: get_environment .. automethod:: get_command .. automethod:: get_subsystem =============================== = ================================= = Server channel info methods ================================= = .. automethod:: get_terminal_type .. automethod:: get_terminal_size .. automethod:: get_terminal_mode .. automethod:: get_x11_display .. automethod:: get_agent_path ================================= = ============================== = Server channel read methods ============================== = .. automethod:: pause_reading .. automethod:: resume_reading ============================== = ======================================= = Server channel write methods ======================================= = .. automethod:: can_write_eof .. automethod:: get_write_buffer_size .. automethod:: set_write_buffer_limits .. automethod:: write .. automethod:: writelines .. automethod:: write_stderr .. automethod:: writelines_stderr .. 
automethod:: write_eof ======================================= = ================================= = Other server channel methods ================================= = .. automethod:: set_xon_xoff .. automethod:: exit .. automethod:: exit_with_signal ================================= = ============================= = General channel close methods ============================= = .. automethod:: abort .. automethod:: close .. automethod:: wait_closed ============================= = SSHLineEditorChannel -------------------- .. autoclass:: SSHLineEditorChannel() ============================= = Line editor methods ============================= = .. automethod:: set_line_mode .. automethod:: set_echo ============================= = SSHTCPChannel ------------- .. autoclass:: SSHTCPChannel() ============================== = General channel info methods ============================== = .. automethod:: get_extra_info ============================== = ============================== = General channel read methods ============================== = .. automethod:: pause_reading .. automethod:: resume_reading ============================== = ======================================= = General channel write methods ======================================= = .. automethod:: can_write_eof .. automethod:: get_write_buffer_size .. automethod:: set_write_buffer_limits .. automethod:: write .. automethod:: writelines .. automethod:: write_eof ======================================= = ============================= = General channel close methods ============================= = .. automethod:: abort .. automethod:: close .. automethod:: wait_closed ============================= = SSHUNIXChannel -------------- .. autoclass:: SSHUNIXChannel() ============================== = General channel info methods ============================== = .. automethod:: get_extra_info ============================== = ============================== = General channel read methods ============================== = .. automethod:: pause_reading .. automethod:: resume_reading ============================== = ======================================= = General channel write methods ======================================= = .. automethod:: can_write_eof .. automethod:: get_write_buffer_size .. automethod:: set_write_buffer_limits .. automethod:: write .. automethod:: writelines .. automethod:: write_eof ======================================= = ============================= = General channel close methods ============================= = .. automethod:: abort .. automethod:: close .. automethod:: wait_closed ============================= = Listener Classes ================ SSHListener ----------- .. autoclass:: SSHListener() =========================== = .. automethod:: get_port .. automethod:: close .. automethod:: wait_closed =========================== = Stream Classes ============== SSHReader --------- .. autoclass:: SSHReader() ============================== = .. autoattribute:: channel .. automethod:: get_extra_info .. automethod:: at_eof .. automethod:: read .. automethod:: readline .. automethod:: readuntil .. automethod:: readexactly ============================== = SSHWriter --------- .. autoclass:: SSHWriter() ============================== = .. autoattribute:: channel .. automethod:: get_extra_info .. automethod:: can_write_eof .. automethod:: close .. automethod:: drain .. automethod:: write .. automethod:: writelines .. automethod:: write_eof ============================== = SFTP Support ============ SFTPClient ---------- .. 
autoclass:: SFTPClient() ===================== = File transfer methods ===================== = .. automethod:: get .. automethod:: put .. automethod:: copy .. automethod:: mget .. automethod:: mput .. automethod:: mcopy ===================== = ========================================================================================== = File access methods ========================================================================================== = .. automethod:: open(path, mode='r', attrs=SFTPAttrs(), encoding='utf-8', errors='strict') .. automethod:: truncate .. automethod:: rename .. automethod:: posix_rename .. automethod:: remove .. automethod:: unlink .. automethod:: readlink .. automethod:: symlink .. automethod:: link .. automethod:: realpath ========================================================================================== = ============================= = File attribute access methods ============================= = .. automethod:: stat .. automethod:: lstat .. automethod:: setstat .. automethod:: statvfs .. automethod:: chown .. automethod:: chmod .. automethod:: utime .. automethod:: exists .. automethod:: lexists .. automethod:: getatime .. automethod:: getmtime .. automethod:: getsize .. automethod:: isdir .. automethod:: isfile .. automethod:: islink ============================= = ============================================== = Directory access methods ============================================== = .. automethod:: chdir .. automethod:: getcwd .. automethod:: mkdir(path, attrs=SFTPAttrs()) .. automethod:: rmdir .. automethod:: readdir .. automethod:: listdir .. automethod:: glob ============================================== = =========================== = Cleanup methods =========================== = .. automethod:: exit .. automethod:: wait_closed =========================== = SFTPClientFile -------------- .. autoclass:: SFTPClientFile() ================================================ = .. automethod:: read .. automethod:: write .. automethod:: seek(offset, from_what=SEEK_SET) .. automethod:: tell .. automethod:: stat .. automethod:: setstat .. automethod:: statvfs .. automethod:: truncate .. automethod:: chown .. automethod:: chmod .. automethod:: utime .. automethod:: fsync .. automethod:: close ================================================ = SFTPServer ---------- .. autoclass:: SFTPServer ================================== = Path remapping and display methods ================================== = .. automethod:: format_user .. automethod:: format_group .. automethod:: format_longname .. automethod:: map_path .. automethod:: reverse_map_path ================================== = ============================ = File access methods ============================ = .. automethod:: open .. automethod:: close .. automethod:: read .. automethod:: write .. automethod:: rename .. automethod:: posix_rename .. automethod:: remove .. automethod:: readlink .. automethod:: symlink .. automethod:: link .. automethod:: realpath ============================ = ============================= = File attribute access methods ============================= = .. automethod:: stat .. automethod:: lstat .. automethod:: fstat .. automethod:: setstat .. automethod:: fsetstat .. automethod:: statvfs .. automethod:: fstatvfs ============================= = ======================== = Directory access methods ======================== = .. automethod:: listdir .. automethod:: mkdir .. automethod:: rmdir ======================== = ===================== = Cleanup methods ===================== = .. 
automethod:: exit ===================== = SFTPAttrs --------- .. autoclass:: SFTPAttrs() SFTPVFSAttrs ------------ .. autoclass:: SFTPVFSAttrs() SFTPName -------- .. autoclass:: SFTPName() .. index:: Public key and certificate support .. _PublicKeySupport: Public Key Support ================== AsyncSSH has extensive public key and certificate support. Supported public key types include DSA, RSA, and ECDSA. In addition, Ed25519 keys are supported if the libnacl package and libsodium library are installed. Supported certificate types include OpenSSH version 01 certificates for DSA, RSA, ECDSA, and Ed25519 keys and X.509 certificates for DSA, RSA, and ECDSA keys. Support is also available for the certificate critical options of force-command and source-address and the extensions permit-X11-forwarding, permit-agent-forwarding, permit-port-forwarding, and permit-pty in OpenSSH certificates. Several public key and certificate formats are supported including PKCS#1 and PKCS#8 DER and PEM, OpenSSH, RFC4716, and X.509 DER and PEM formats. PEM and PKCS#8 password-based encryption of private keys is supported, as is OpenSSH private key encryption when the bcrypt package is installed. .. index:: Specifying private keys .. _SpecifyingPrivateKeys: Specifying private keys ----------------------- Private keys may be passed into AsyncSSH in a variety of forms. The simplest option is to pass the name of a file to read one or more private keys from. An alternate form involves passing in a list of values which can be either a reference to a private key or a tuple containing a reference to a private key and a reference to a corresponding certificate or certificate chain. Key references can either be the name of a file to load a key from, a byte string to import as a key, or an already loaded :class:`SSHKey` private key. See the function :func:`import_private_key` for the list of supported private key formats. Certificate references can be the name of a file to load a certificate from, a byte string to import as a certificate, an already loaded :class:`SSHCertificate`, or ``None`` if no certificate should be associated with the key. Whenever a filename is provided to read the private key from, an attempt is made to load a corresponding certificate or certificate chain from a file constructed by appending '-cert.pub' to the end of the name. X.509 certificates may also be provided in the same file as the private key, when using DER or PEM format. When using X.509 certificates, a list of certificates can also be provided. These certificates should form a trust chain from a user or host certificate up to some self-signed root certificate authority which is trusted by the remote system. New private keys can be generated using the :func:`generate_private_key` function. The resulting :class:`SSHKey` objects have methods which can then be used to export the generated keys in several formats for consumption by other tools, as well as methods for generating new OpenSSH or X.509 certificates. .. index:: Specifying public keys .. _SpecifyingPublicKeys: Specifying public keys ---------------------- Public keys may be passed into AsyncSSH in a variety of forms. The simplest option is to pass the name of a file to read one or more public keys from. An alternate form involves passing in a list of values each of which can be either the name of a file to load a key from, a byte string to import it from, or an already loaded :class:`SSHKey` public key. 
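As a sketch of these forms (the file names below are placeholders, not files shipped with AsyncSSH), a list of public keys could be assembled as::

    import asyncssh

    trusted_keys = [
        'ssh_host_ed25519_key.pub',                          # key file name
        open('ssh_host_rsa_key.pub', 'rb').read(),           # raw key bytes
        asyncssh.read_public_key('ssh_host_ecdsa_key.pub')   # loaded SSHKey
    ]

Values in this form can be supplied wherever the API documents a list of public keys, such as the trusted key lists described under :ref:`KnownHosts` below.
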
See the function :func:`import_public_key` for the list of supported public key formats. .. index:: Specifying certificates .. _SpecifyingCertificates: Specifying certificates ----------------------- Certificates may be passed into AsyncSSH in a variety of forms. The simplest option is to pass the name of a file to read one or more certificates from. An alternate form involves passing in a list of values each of which can be either the name of a file to load a certificate from, a byte string to import it from, or an already loaded :class:`SSHCertificate` object. See the function :func:`import_certificate` for the list of supported certificate formats. .. index:: Specifying X.509 subject names .. _SpecifyingX509Subjects: Specifying X.509 subject names ------------------------------ X.509 certificate subject names may be specified in place of public keys or certificates in authorized_keys and known_hosts files, allowing any X.509 certificate which matches that subject name to be considered a known host or authorized key. The syntax supported for this is compatible with PKIX-SSH, which adds X.509 certificate support to OpenSSH. To specify a subject name pattern instead of a specific certificate, base64-encoded certificate data should be replaced with the string 'Subject:' followed by a a comma-separated list of X.509 relative distinguished name components. AsyncSSH extends the PKIX-SSH syntax to also support matching on a prefix of a subject name. To indicate this, a partial subject name can be specified which ends in ',*'. Any subject which matches the relative distinguished names listed before the ",*" will be treated as a match, even if the certificate provided has additional relative distinguished names following what was matched. .. index:: Specifying X.509 purposes .. _SpecifyingX509Purposes: Specifying X.509 purposes ------------------------- When performing X.509 certificate authentication, AsyncSSH can be passed in an allowed set of ExtendedKeyUsage purposes. Purposes are matched in X.509 certificates as OID values, but AsyncSSH also allows the following well-known purpose values to be specified by name: ================= ================== Name OID ================= ================== serverAuth 1.3.6.1.5.5.7.3.1 clientAuth 1.3.6.1.5.5.7.3.2 secureShellClient 1.3.6.1.5.5.7.3.20 secureShellServer 1.3.6.1.5.5.7.3.21 ================= ================== Values not in the list above can be specified directly by OID as a dotted numeric string value. Either a single value or a list of values can be provided. The check succeeds if any of the specified values are present in the certificate's ExtendedKeyUsage. It will also succeed if the certificate does not contain an ExtendedKeyUsage or if the ExtendedKeyUsage contains the OID 2.5.29.37.0, which indicates the certificate can be used for any purpose. This check defaults to requiring a purpose of 'secureShellCient' for client certificates and 'secureShellServer' for server certificates and should not normally need to be changed. However, certificates which contain other purposes can be supported by providing alternate values to match against, or by passing in the purpose 'any' to disable this checking. .. index:: Specifying time values .. _SpecifyingTimeValues: Specifying time values ---------------------- When generating certificates, an optional validity interval can be specified using the ``valid_after`` and ``valid_before`` parameters to the :meth:`generate_user_certificate() ` and :meth:`generate_host_certificate() ` methods. 
These values can be specified in any of the following ways: * An int or float UNIX epoch time, such as what is returned by :func:`time.time`. * A :class:`datetime.datetime` value. * A string value of ``now`` to request the current time. * A string value in the form ``YYYYMMDD`` to specify an absolute date. * A string value in the form ``YYYYMMDDHHMMSS`` to specify an absolute date and time. * A relative time made up of a mix of positive or negative numbers and the letters 'w', 'd', 'h', 'm', and 's', representing an offset from the current time in weeks, days, hours, minutes, or seconds, respectively. Multiple of these values can be included. For instance, '+1w2d3h' means 1 week, 2 days, and 3 hours in the future. SSHKey ------ .. autoclass:: SSHKey() ============================================== = .. automethod:: get_algorithm .. automethod:: get_comment .. automethod:: set_comment .. automethod:: convert_to_public .. automethod:: generate_user_certificate .. automethod:: generate_host_certificate .. automethod:: generate_x509_user_certificate .. automethod:: generate_x509_host_certificate .. automethod:: generate_x509_ca_certificate .. automethod:: export_private_key .. automethod:: export_public_key .. automethod:: write_private_key .. automethod:: write_public_key .. automethod:: append_private_key .. automethod:: append_public_key ============================================== = SSHKeyPair ---------- .. autoclass:: SSHKeyPair() ============================= = .. automethod:: get_key_type .. automethod:: get_algorithm .. automethod:: get_comment .. automethod:: set_comment ============================= = SSHCertificate -------------- .. autoclass:: SSHCertificate() ================================== = .. automethod:: get_algorithm .. automethod:: get_comment .. automethod:: set_comment .. automethod:: export_certificate .. automethod:: write_certificate .. automethod:: append_certificate ================================== = generate_private_key -------------------- .. autofunction:: generate_private_key import_private_key ------------------ .. autofunction:: import_private_key import_public_key ----------------- .. autofunction:: import_public_key import_certificate ------------------ .. autofunction:: import_certificate read_private_key ---------------- .. autofunction:: read_private_key read_public_key --------------- .. autofunction:: read_public_key read_certificate ---------------- .. autofunction:: read_certificate read_private_key_list --------------------- .. autofunction:: read_private_key_list read_public_key_list -------------------- .. autofunction:: read_public_key_list read_certificate_list --------------------- .. autofunction:: read_certificate_list load_keypairs ------------- .. autofunction:: load_keypairs load_public_keys ---------------- .. autofunction:: load_public_keys load_certificates ----------------- .. autofunction:: load_certificates .. index:: SSH agent support SSH Agent Support ================= AsyncSSH supports the ability to use private keys managed by the OpenSSH ssh-agent on UNIX systems. It can connect via a UNIX domain socket to the agent and offload all private key operations to it, avoiding the need to read these keys into AsyncSSH itself. An ssh-agent is automatically used in :func:`create_connection` when a valid ``SSH_AUTH_SOCK`` is set in the environment. An alternate path to the agent can be specified via the ``agent_path`` argument to this function. An ssh-agent can also be accessed directly from AsyncSSH by calling :func:`connect_agent`. 
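For instance, a minimal sketch of querying the agent directly (assuming an agent is reachable through the default ``SSH_AUTH_SOCK`` socket) might be::

    import asyncio
    import asyncssh

    @asyncio.coroutine
    def show_agent_keys():
        # connect_agent() with no arguments falls back to SSH_AUTH_SOCK
        agent = yield from asyncssh.connect_agent()

        try:
            for key in (yield from agent.get_keys()):
                print(key.get_algorithm(), key.get_comment())
        finally:
            agent.close()

    asyncio.get_event_loop().run_until_complete(show_agent_keys())
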
When successful, this function returns an :class:`SSHAgentClient` which can be used to get a list of available keys, add and remove keys, and lock and unlock access to this agent. SSH agent forwarding may be enabled when making outbound SSH connections by specifying the ``agent_forwarding`` argument when calling :func:`create_connection`, allowing processes running on the server to tunnel requests back over the SSH connection to the client's ssh-agent. Agent forwarding can be enabled when starting an SSH server by specifying the ``agent_forwarding`` argument when calling :func:`create_server`. In this case, the client's ssh-agent can be accessed from the server by passing the :class:`SSHServerConnection` as the argument to :func:`connect_agent` instead of a local path. Alternately, when an :class:`SSHServerChannel` has been opened, the :meth:`get_agent_path() ` method may be called on it to get a path to a UNIX domain socket which can be passed as the ``SSH_AUTH_SOCK`` to local applications which need this access. Any requests sent to this socket are forwarded over the SSH connection to the client's ssh-agent. SSHAgentClient -------------- .. autoclass:: SSHAgentClient() ===================================== = .. automethod:: get_keys .. automethod:: add_keys .. automethod:: add_smartcard_keys .. automethod:: remove_keys .. automethod:: remove_smartcard_keys .. automethod:: remove_all .. automethod:: lock .. automethod:: unlock .. automethod:: query_extensions .. automethod:: close ===================================== = SSHAgentKeyPair --------------- .. autoclass:: SSHAgentKeyPair() ============================= = .. automethod:: get_key_type .. automethod:: get_algorithm .. automethod:: get_comment .. automethod:: set_comment .. automethod:: remove ============================= = connect_agent ------------- .. autofunction:: connect_agent .. index:: Known hosts .. _KnownHosts: Known Hosts =========== AsyncSSH supports OpenSSH-style known_hosts files, including both plain and hashed host entries. Regular and negated host patterns are supported in plain entries. AsyncSSH also supports the ``@cert_authority`` marker to indicate keys and certificates which should be trusted as certificate authorities and the ``@revoked`` marker to indicate keys and certificates which should be explicitly reported as no longer trusted. .. index:: Specifying known hosts .. _SpecifyingKnownHosts: Specifying known hosts ---------------------- Known hosts may be passed into AsyncSSH via the ``known_hosts`` argument to :func:`create_connection`. This can be the name of a file containing a list of known hosts, a byte string containing a list of known hosts, or an :class:`SSHKnownHosts` object which was previously imported from a string by calling :func:`import_known_hosts` or read from a file by calling :func:`read_known_hosts`. In all of these cases, the host patterns in the list will be compared against the target host, address, and port being connected to and the matching trusted host keys, trusted CA keys, revoked keys, trusted X.509 certificates, revoked X.509 certificates, trusted X.509 subject names, and revoked X.509 subject names will be returned. Alternately, a function can be passed in as the ``known_hosts`` argument that accepts a target host, address, and port and returns lists containing trusted host keys, trusted CA keys, revoked keys, trusted X.509 certificates, revoked X.509 certificates, trusted X.509 subject names, and revoked X.509 subject names. 
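A sketch of such a callable (the key file name here is a placeholder) could look like::

    import asyncssh

    def my_known_hosts(host, addr, port):
        # Only the first of the seven lists is populated here: trusted host
        # keys, trusted CA keys, revoked keys, trusted X.509 certificates,
        # revoked X.509 certificates, trusted X.509 subject names, and
        # revoked X.509 subject names
        trusted_host_keys = [asyncssh.read_public_key('trusted_host.pub')]

        return (trusted_host_keys, [], [], [], [], [], [])

This function would then be passed as the ``known_hosts`` argument to :func:`create_connection`.
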
If no matching is required and the caller already knows exactly what the above values should be, these seven lists can also be provided directly in the ``known_hosts`` argument. See :ref:`SpecifyingPublicKeys` for the allowed form of public key values which can be provided, :ref:`SpecifyingCertificates` for the allowed form of certificates, and :ref:`SpecifyingX509Subjects` for the allowed form of X.509 subject names. SSHKnownHosts ------------- .. autoclass:: SSHKnownHosts() ===================== = .. automethod:: match ===================== = import_known_hosts ------------------ .. autofunction:: import_known_hosts read_known_hosts ---------------- .. autofunction:: read_known_hosts match_known_hosts ----------------- .. autofunction:: match_known_hosts .. index:: Authorized keys .. _AuthorizedKeys: Authorized Keys =============== AsyncSSH supports OpenSSH-style authorized_keys files, including the cert-authority option to validate user certificates, enforcement of from and principals options to restrict key matching, enforcement of no-X11-forwarding, no-agent-forwarding, no-pty, no-port-forwarding, and permitopen options, and support for command and environment options. .. index:: Specifying authorized keys .. _SpecifyingAuthorizedKeys: Specifying authorized keys -------------------------- Authorized keys may be passed into AsyncSSH via the ``authorized_client_keys`` argument to :func:`create_server` or by calling :meth:`set_authorized_keys() ` on the :class:`SSHServerConnection` from within the :meth:`begin_auth() ` method in :class:`SSHServer`. Authorized keys can be provided as either the name of a file to read the keys from or an :class:`SSHAuthorizedKeys` object which was previously imported from a string by calling :func:`import_authorized_keys` or read from a file by calling :func:`read_authorized_keys`. An authorized keys file may contain public keys or X.509 certificates in OpenSSH format or X.509 certificate subject names. See :ref:`SpecifyingX509Subjects` for more information on using subject names in place of specific X.509 certificates. SSHAuthorizedKeys ----------------- .. autoclass:: SSHAuthorizedKeys() import_authorized_keys ---------------------- .. autofunction:: import_authorized_keys read_authorized_keys -------------------- .. autofunction:: read_authorized_keys .. index:: Exceptions .. _Exceptions: Exceptions ========== PasswordChangeRequired ---------------------- .. autoexception:: PasswordChangeRequired BreakReceived ------------- .. autoexception:: BreakReceived SignalReceived -------------- .. autoexception:: SignalReceived TerminalSizeChanged ------------------- .. autoexception:: TerminalSizeChanged DisconnectError --------------- .. autoexception:: DisconnectError ChannelOpenError ---------------- .. autoexception:: ChannelOpenError ProcessError ------------ .. autoexception:: ProcessError SFTPError --------- .. autoexception:: SFTPError KeyImportError -------------- .. autoexception:: KeyImportError KeyExportError -------------- .. autoexception:: KeyExportError KeyEncryptionError ------------------ .. autoexception:: KeyEncryptionError KeyGenerationError ------------------ .. autoexception:: KeyGenerationError .. index:: Supported algorithms .. _SupportedAlgorithms: Supported Algorithms ==================== .. index:: Key exchange algorithms .. 
_KexAlgs: Key exchange algorithms ----------------------- The following are the key exchange algorithms currently supported by AsyncSSH: | gss-gex-sha256 | gss-gex-sha1 | gss-group1-sha1 | gss-group14-sha1 | gss-group14-sha256 | gss-group15-sha512 | gss-group16-sha512 | gss-group17-sha512 | gss-group18-sha512 | curve25519-sha256 | curve25519-sha256\@libssh.org | ecdh-sha2-nistp521 | ecdh-sha2-nistp384 | ecdh-sha2-nistp256 | diffie-hellman-group-exchange-sha256 | diffie-hellman-group-exchange-sha1 | diffie-hellman-group1-sha1 | diffie-hellman-group14-sha1 | diffie-hellman-group14-sha256 | diffie-hellman-group15-sha512 | diffie-hellman-group16-sha512 | diffie-hellman-group17-sha512 | diffie-hellman-group18-sha512 Curve25519 support is only available when the libnacl package and libsodium library are installed. GSS authentication support is only available when the gssapi package is installed on UNIX or the pypiwin32 package is installed on Windows. .. index:: Encryption algorithms .. _EncryptionAlgs: Encryption algorithms --------------------- The following are the encryption algorithms currently supported by AsyncSSH: | chacha20-poly1305\@openssh.com | aes256-ctr | aes192-ctr | aes128-ctr | aes256-gcm\@openssh.com | aes128-gcm\@openssh.com | aes256-cbc | aes192-cbc | aes128-cbc | 3des-cbc | blowfish-cbc | cast128-cbc | arcfour256 | arcfour128 | arcfour Chacha20-poly1305 support is only available when the libnacl package and libsodium library are installed. .. index:: MAC algorithms .. _MACAlgs: MAC algorithms -------------- The following are the MAC algorithms currently supported by AsyncSSH: | umac-64-etm\@openssh.com | umac-128-etm\@openssh.com | hmac-sha2-256-etm\@openssh.com | hmac-sha2-512-etm\@openssh.com | hmac-sha1-etm\@openssh.com | hmac-md5-etm\@openssh.com | hmac-sha2-256-96-etm\@openssh.com | hmac-sha2-512-96-etm\@openssh.com | hmac-sha1-96-etm\@openssh.com | hmac-md5-96-etm\@openssh.com | umac-64\@openssh.com | umac-128\@openssh.com | hmac-sha2-256 | hmac-sha2-512 | hmac-sha1 | hmac-md5 | hmac-sha2-256-96 | hmac-sha2-512-96 | hmac-sha1-96 | hmac-md5-96 UMAC support is only available when the nettle library is installed. .. index:: Compression algorithms .. _CompressionAlgs: Compression algorithms ---------------------- The following are the compression algorithms currently supported by AsyncSSH: | zlib\@openssh.com | zlib | none .. index:: Signature algorithms .. _SignatureAlgs: Signature algorithms -------------------- The following are the public key signature algorithms currently supported by AsyncSSH: | x509v3-ecdsa-sha2-nistp521 | x509v3-ecdsa-sha2-nistp384 | x509v3-ecdsa-sha2-nistp256 | x509v3-rsa2048-sha256 | x509v3-ssh-rsa | x509v3-ssh-dss | ssh-ed25519 | ecdsa-sha2-nistp521 | ecdsa-sha2-nistp384 | ecdsa-sha2-nistp256 | rsa-sha2-256 | rsa-sha2-512 | ssh-rsa | ssh-dss .. index:: Public key & certificate algorithms .. 
_PublicKeyAlgs: Public key & certificate algorithms ----------------------------------- The following are the public key and certificate algorithms currently supported by AsyncSSH: | x509v3-ecdsa-sha2-nistp521 | x509v3-ecdsa-sha2-nistp384 | x509v3-ecdsa-sha2-nistp256 | x509v3-rsa2048-sha256 | x509v3-ssh-rsa | x509v3-ssh-dss | ssh-ed25519-cert-v01\@openssh.com | ecdsa-sha2-nistp521-cert-v01\@openssh.com | ecdsa-sha2-nistp384-cert-v01\@openssh.com | ecdsa-sha2-nistp256-cert-v01\@openssh.com | ssh-rsa-cert-v01\@openssh.com | ssh-dss-cert-v01\@openssh.com | ssh-ed25519 | ecdsa-sha2-nistp521 | ecdsa-sha2-nistp384 | ecdsa-sha2-nistp256 | rsa-sha2-256 | rsa-sha2-512 | ssh-rsa | ssh-dss Ed25519 support is only available when the libnacl package and libsodium library are installed. .. index:: Constants .. _Constants: Constants ========= .. index:: Disconnect reasons .. _DisconnectReasons: Disconnect reasons ------------------ The following values defined in section 11.1 of :rfc:`4253#section-11.1` can be specified as disconnect reason codes: | DISC_HOST_NOT_ALLOWED_TO_CONNECT | DISC_PROTOCOL_ERROR | DISC_KEY_EXCHANGE_FAILED | DISC_RESERVED | DISC_MAC_ERROR | DISC_COMPRESSION_ERROR | DISC_SERVICE_NOT_AVAILABLE | DISC_PROTOCOL_VERSION_NOT_SUPPORTED | DISC_HOST_KEY_NOT_VERIFYABLE | DISC_CONNECTION_LOST | DISC_BY_APPLICATION | DISC_TOO_MANY_CONNECTIONS | DISC_AUTH_CANCELLED_BY_USER | DISC_NO_MORE_AUTH_METHODS_AVAILABLE | DISC_ILLEGAL_USER_NAME .. index:: Channel open failure reasons .. _ChannelOpenFailureReasons: Channel open failure reasons ---------------------------- The following values defined in section 5.1 of :rfc:`4254#section-5.1` can be specified as channel open failure reason codes: | OPEN_ADMINISTRATIVELY_PROHIBITED | OPEN_CONNECT_FAILED | OPEN_UNKNOWN_CHANNEL_TYPE | OPEN_RESOURCE_SHORTAGE In addition, AsyncSSH defines the following channel open failure reason codes: | OPEN_REQUEST_X11_FORWARDING_FAILED | OPEN_REQUEST_PTY_FAILED | OPEN_REQUEST_SESSION_FAILED .. index:: SFTP error codes .. _SFTPErrorCodes: SFTP error codes ---------------- The following values defined in the `SSH File Transfer Internet Draft `_ can be specified as SFTP error codes: | FX_OK | FX_EOF | FX_NO_SUCH_FILE | FX_PERMISSION_DENIED | FX_FAILURE | FX_BAD_MESSAGE | FX_NO_CONNECTION | FX_CONNECTION_LOST | FX_OP_UNSUPPORTED .. index:: Extended data types .. _ExtendedDataTypes: Extended data types ------------------- The following values defined in section 5.2 of :rfc:`4254#section-5.2` can be specified as SSH extended channel data types: | EXTENDED_DATA_STDERR .. index:: POSIX terminal modes .. 
_PTYModes: POSIX terminal modes -------------------- The following values defined in section 8 of :rfc:`4254#section-8` can be specified as PTY mode opcodes: | PTY_OP_END | PTY_VINTR | PTY_VQUIT | PTY_VERASE | PTY_VKILL | PTY_VEOF | PTY_VEOL | PTY_VEOL2 | PTY_VSTART | PTY_VSTOP | PTY_VSUSP | PTY_VDSUSP | PTY_VREPRINT | PTY_WERASE | PTY_VLNEXT | PTY_VFLUSH | PTY_VSWTCH | PTY_VSTATUS | PTY_VDISCARD | PTY_IGNPAR | PTY_PARMRK | PTY_INPCK | PTY_ISTRIP | PTY_INLCR | PTY_IGNCR | PTY_ICRNL | PTY_IUCLC | PTY_IXON | PTY_IXANY | PTY_IXOFF | PTY_IMAXBEL | PTY_ISIG | PTY_ICANON | PTY_XCASE | PTY_ECHO | PTY_ECHOE | PTY_ECHOK | PTY_ECHONL | PTY_NOFLSH | PTY_TOSTOP | PTY_IEXTEN | PTY_ECHOCTL | PTY_ECHOKE | PTY_PENDIN | PTY_OPOST | PTY_OLCUC | PTY_ONLCR | PTY_OCRNL | PTY_ONOCR | PTY_ONLRET | PTY_CS7 | PTY_CS8 | PTY_PARENB | PTY_PARODD | PTY_OP_ISPEED | PTY_OP_OSPEED asyncssh-1.11.1/docs/changes.rst000066400000000000000000001214731320320510200165140ustar00rootroot00000000000000.. currentmodule:: asyncssh Change Log ========== Release 1.11.1 (15 Nov 2017) ---------------------------- * Switched to using PBKDF2 implementation provided by PyCA, replacing a much slower pure-Python implementation used in earlier releases. * Improved support for file-like objects in process I/O redirection, properly handling objects which don't support fileno() and allowing both text and binary file objects based on whether they have an 'encoding' member. * Changed PEM parser to be forgiving of trailing blank lines. * Updated documentation to note lack of support in OpenSSH for send_signal(), terminate(), and kill() channel requests. * Updated unit tests to work better with OpenSSH 7.6. * Updated Travis CI config to test with more recent Python versions. Release 1.11.0 (9 Sep 2017) --------------------------- * Added support for X.509 certificate based client and server authentication, as defined in RFC 6187. * DSA, RSA, and ECDSA keys are supported. * New methods are available on SSHKey private keys to generate X.509 user, host, and CA certificates. * Authorized key and known host support has been enhanced to support matching on X.509 certificates and X.509 subject names. * New arguments have been added to create_connection() and create_server() to specify X.509 trusted root CAs, X.509 trusted root CA hash directories, and allowed X.509 certificate purposes. * A new load_certificates() function has been added to more easily pre-load a list of certificates from byte strings or files. * Support for including and validating OCSP responses is not yet available, but may be added in a future release. * This support adds a new optional dependency on pyOpenSSL in setup.py. * Added command, subsystem, and environment properties to SSHProcess, SSHCompletedProcess, and ProcessError classes, as well as stdout and stderr properties in ProcessError which mirror what is already present in SSHCompletedProcess. Thanks go to iforapsy for suggesting this. * Worked around a datetime.max bug on Windows. * Increased the build timeout on TravisCI to avoid build failures. Release 1.10.1 (19 May 2017) ---------------------------- * Fixed SCP to properly call exit() on SFTPServer when the copy completes. Thanks go to Arthur Darcet for discovering this and providing a suggested fix. * Added support for passphrase to be specified when loading default client keys, and to ignore encrypted default keys if no passphrase is specified. * Added additional known hosts test cases. Thanks go to Rafael Viotti for providing these. 
* Increased the default number of rounds for OpenSSH-compatible bcrypt private key encryption to avoid a warning in the latest version of the bcrypt module, and added a note that the encryption strength scales linearly with the rounds value, not logarithmically. * Fixed SCP unit test errors on Windows. * Fixed some issues with Travis and Appveyor CI builds. Release 1.10.0 (5 May 2017) --------------------------- * Added SCP client and server support. The new asyncssh.scp() function can get and put files on a remote SCP server and copy files between two or more remote SCP servers, with options similar to what was previously supported for SFTP. On the server side, an SFTPServer used to serve files over SFTP can also serve files over SCP by simply setting allow_scp to True in the call to create_server(). * Added a new SSHServerProcess class which supports I/O redirection on inbound connections to an SSH server, mirroring the SSHClientProcess class added previously for outbound SSH client connections. * Enabled TCP keepalive on SSH client and server connections. * Enabled Python 3 highlighting in Sphinx documentation. * Fixed a bug where a previously loaded SSHKnownHosts object wasn't properly accepted as a known_hosts value in create_connection() and enhanced known_hosts to accept a callable to allow applications to provide their own function to return trusted host keys. * Fixed a bug where an exception was raised if the connection closed while waiting for an asynchronous authentication callback to complete. * Fixed a bug where empty passwords weren't being properly supported. Release 1.9.0 (18 Feb 2017) --------------------------- * Added support for GSSAPI key exchange and authentication when the "gssapi" module is installed on UNIX or the "sspi" module from pypiwin32 is installed on Windows. * Added support for additional Diffie Hellman groups, and added the ability for Diffie Hellman and GSS group exchange to select larger group sizes. * Added overridable methods format_user() and format_group() to format user and group names in the SFTP server, defaulting to the previous behavior of using pwd.getpwuid() and grp.getgrgid() on platforms that support those. * Added an optional progress reporting callback on SFTP file transfers, and made the block size for these transfers configurable. * Added append_private_key(), append_public_key(), and append_certificate() methods on the corresponding key and certificate classes to simplify the creation of files containing a list of keys/certificates. * Updated readdir to break responses into chunks to avoid hitting maximum message size limits on large directories. * Updated SFTP to work better on Windows, properly handling drive letters and conversion between forward and back slashes in paths and handling setting of attributes on open files and proper support for POSIX rename. Also, file closes now block until the close completes, to avoid issues with file locking. * Updated the unit tests to run on Windows, and enabled continuous integration builds for Windows to automatically run on Appveyor. Release 1.8.1 (29 Dec 2016) --------------------------- * Fixed an issue in attempting to load the 'nettle' library on Windows. Release 1.8.0 (29 Dec 2016) --------------------------- * Added support for forwarding X11 connections.
When requested, AsyncSSH clients will allow remote X11 applications to tunnel data back to a local X server and AsyncSSH servers can request an X11 DISPLAY value to export to X11 applications they launch which will tunnel data back to an X server associated with the client. * Improved ssh-agent forwarding support on UNIX to allow AsyncSSH servers to request an SSH_AUTH_SOCK value to export to applications they launch in order to access the client's ssh-agent. Previously, there was support for agent forwarding on server connections within AsyncSSH itself, but this forwarding was not made available to other applications. * Added support for PuTTY's Pageant agent on Windows systems, providing functionality similar to the OpenSSH agent on UNIX. AsyncSSH client connections from Windows can now access keys stored in the Pageant agent when they perform public key authentication. * Added support for the umac-64 and umac-128 MAC algorithms, compatible with the implementation in OpenSSH. These algorithms are preferred over the HMAC algorithms when both are available and the cipher chosen doesn't already include a MAC. * Added curve25519-sha256 as a supported key exchange algorithm. This algorithm is identical to the previously supported algorithm named 'curve25519-sha256\@libssh.org', matching what was done in OpenSSH 7.3. Either name may now be used to request this type of key exchange. * Changed the default order of key exchange algorithms to prefer the curve25519-sha256 algorithm over the ecdh-sha2-nistp algorithms. * Added support for a readuntil() function in SSHReader, modeled after the readuntil() function in asyncio.StreamReader added in Python 3.5.2. Thanks go to wwjiang for suggesting this and providing an example implementation. * Fixed issues where the explicitly provided event loop value was not being passed through to all of the places which needed it. Thanks go to Vladimir Rutsky for pointing out this problem and providing some initial fixes. * Improved error handling when port forwarding is requested for a port number outside of the range 0-65535. * Disabled use of IPv6 in unit tests when opening local loopback sockets to avoid issues with incomplete IPv6 support in TravisCI. * Changed the unit tests to always start with a known set of environment variables rather than inheriting the environment from the shell running the tests. This was leading to test breakage in some cases. Release 1.7.3 (22 Nov 2016) --------------------------- * Updated unit tests to run properly in environments where OpenSSH and OpenSSL are not installed. * Updated a process unit test to not depend on the system's default file encoding being UTF-8. * Updated Mac TravisCI builds to use Xcode 8.1. * Cleaned up some wording in the documentation. Release 1.7.2 (28 Oct 2016) --------------------------- * Fixed an issue with preserving file access times in SFTP, and updated the unit tests to more accurately detect this kind of failure. * Fixed some markup errors in the documentation. * Fixed a small error in the change log for release 1.7.0 regarding the newly added Diffie Hellman key exchange algorithms. Release 1.7.1 (7 Oct 2016) -------------------------- * Fixed an error that prevented the docs from building. Release 1.7.0 (7 Oct 2016) -------------------------- * Added support for group 14, 16, and 18 Diffie Hellman key exchange algorithms which use SHA-256 and SHA-512.
* Added support for using SHA-256 and SHA-512 based signature algorithms for RSA keys and support for OpenSSH extension negotiation to advertise these signature algorithms. * Added new load_keypairs and load_public_keys API functions which support explicitly loading keys using the same syntax that was previously available for specifying client_keys, authorized_client_keys, and server_host_keys arguments when creating SSH clients and servers. * Enhanced the SSH agent client to support adding and removing keys and certificates (including support for constraints) and locking and unlocking the agent. Support has also been added for adding and removing smart card keys in the agent. * Added support for getting and setting a comment value when generating keys and certificates, and decoding and encoding this comment when importing and exporting keys that support it. Currently, this is available for OpenSSH format private keys and OpenSSH and RFC 4716 format public keys. These comment values are also passed on to the SSH agent when keys are added to it. * Fixed a bug in the generation of ECDSA certificates that showed up when trying to use the nistp384 or nistp521 curves. * Updated unit tests to use the new key and certificate generation functions, eliminating the dependency on the ssh-keygen program. * Updated unit tests to use the new SSH agent support when adding keys to the SSH agent, eliminating the dependency on the ssh-add program. * Incorporated a fix from Vincent Bernat for an issue with launching ssh-agent on some systems during unit testing. * Fixed some typos in the documentation found by Jakub Wilk. Release 1.6.2 (4 Sep 2016) -------------------------- * Added generate_user_certificate() and generate_host_certificate() methods to SSHKey class to generate SSH certificates, and export_certificate() and write_certificate() methods on SSHCertificate class to export certificates for use in other tools. * Improved editor unit tests to eliminate timing dependency. * Cleaned up a few minor documentation issues. Release 1.6.1 (27 Aug 2016) --------------------------- * Added generate_private_key() function to create new DSA, RSA, ECDSA, or Ed25519 private keys which can be used as SSH user and host keys. * Removed an unintended dependency in the SSHLineEditor on session objects keeping a private member which referenced the corresponding channel. * Fixed a race condition in SFTP unit tests. * Updated dependencies to require version 1.5 of the cryptography module and started to take advantage of the new one-shot sign and verify APIs it now supports. * Clarified the documentation of the default return value of eof_received(). * Added new multi-user client and server examples, showing a single process opening multiple SSH connections in parallel. * Updated development status and Python versions listed in setup.py. Release 1.6.0 (13 Aug 2016) --------------------------- * Added new create_process() and run() APIs modeled after the "subprocess" module to simplify redirection of stdin, stdout, and stderr and collection of output from remote SSH processes. * Added input line editing and echoing capabilities to better support interactive SSH server applications. AsyncSSH server sessions will now automatically perform input echoing and provide basic line editing capabilities to clients which request a pseudo-terminal, avoiding the need for applications to provide this functionality. * Added the ability to use SSHReader objects as async iterators in Python 3.5, returning input a line at a time.
* Added support for the IUTF8 terminal mode now recognized by OpenSSH 7.3. * Fixed a bug where an SSHReader read() call could return an empty string when it followed a call to readline() instead of blocking until more input was available. * Updated AsyncSSH to use the bcrypt package from PyCA, now that it has support for the kdf function. * Updated the documentation and examples to show how to take advantage of the new features listed here. Release 1.5.6 (18 Jun 2016) --------------------------- * Added support for Python 3.5 asynchronous context managers in SSHConnection, SFTPClient, and SFTPFile, while still maintaining backward compatibility with older Python 3.4 syntax. * Updated bcrypt check in test code to only test features that depend on it when the right version is available. * Switched testing over to using tox to better support testing on multiple versions of Python. * Added tests of new Python 3.5 async syntax. * Expanded Travis CI coverage to test both Python 3.4 and 3.5 on MacOS. * Updated documentation and examples to use Python 3.5 syntax. Release 1.5.5 (11 Jun 2016) --------------------------- * Updated public_key module to make sure the right version of bcrypt is installed before attempting to use it. * Updated forward and sftp module unit tests to work better on Linux. * Changed README links to point at new readthedocs.io domain. Release 1.5.4 (6 Jun 2016) -------------------------- * Added support for setting custom SSH client and server version strings. * Added unit tests for the sftp module, bringing AsyncSSH up to 100% code coverage under test on all modules. * Added new wait_closed() method in SFTPClient class to wait for an SFTP client session to be fully closed. * Fixed an issue with error handling in new parallel SFTP file copy code. * Fixed some other minor issues in SFTP found during unit tests. * Fixed some minor documentation issues. Release 1.5.3 (2 Apr 2016) -------------------------- * Added support for opening tunneled SSH connections, where an SSH connection is opened over another SSH connection's direct TCP/IP channel. * Improve performance of SFTP over high latency connections by having the internal copy method issue multiple read requests in parallel. * Reworked SFTP to mark all coroutine functions explicitly, to provide better compatibility with the new Python 3.5 "await" syntax. * Reworked create_connection() and create_server() functions to do argument checking immediately rather than in the SSHConnection constructors, improving error reporting and avoiding a bug in asyncio which can leak socket objects. * Fixed a hang which could occur when attempting to close an SSH connection with a listener still active. * Fixed an error related to passing keys in via public_key_auth_requested(). * Fixed a potential leak of an SSHAgentClient object when an error occurs while opening a client connection. * Fixed some race conditions related to channel and connection closes. * Fixed some minor documentation issues. * Continued to expand unit test coverage, completing coverage of the connection module. Release 1.5.2 (25 Feb 2016) --------------------------- * Fixed a bug in UNIX domain socket forwarding introduced in 1.5.1 by the TCP_NODELAY change. * Fixed channel code to report when a channel is closed with incomplete Unicode data in the receive buffer. This was previously reported correctly when EOF was received on a channel, but not when it was closed without sending EOF. 
* Added unit tests for channel, forward, and stream modules, partial unit tests for the connection module, and a placeholder for unit tests for the sftp module. Release 1.5.1 (23 Feb 2016) --------------------------- * Added basic support for running AsyncSSH on Windows. Some functionality such as UNIX domain sockets will not work there, and the test suite will not run there yet, but basic functionality has been tested and seems to work. This includes features like bcrypt and support for newer ciphers provided by libnacl when these optional packages are installed. * Greatly improved the performance of known_hosts matching on exact hostnames and addresses. Full wildcard pattern matching is still supported, but entries involving exact hostnames or addresses are now matched thousands of times faster. * Split known_hosts parsing and matching into separate calls so that a known_hosts file can be parsed once and used to make connections to several different hosts. Thanks go to Josh Yudaken for suggesting this and providing a sample implementation. * Updated AsyncSSH to allow SSH agent forwarding when it is requested even when local client keys are used to perform SSH authentication. * Updated channel state machine to better handle close being received while the channel is paused for reading. Previously, some data would not be delivered in this case. * Set TCP_NODELAY on sockets to avoid latency problems caused by TCP delayed ACK. * Fixed a bug where exceptions were not always returned properly when attempting to drain writes on a stream. * Fixed a bug which could leak a socket object after an error opening a local TCP listening socket. * Fixed a number of race conditions uncovered during unit testing. Release 1.5.0 (27 Jan 2016) --------------------------- * Added support for OpenSSH-compatible direct and forwarded UNIX domain socket channels and local and remote UNIX domain socket forwarding. * Added support for client and server side ssh-agent forwarding. * Fixed the open_connection() method on SSHServerConnection to not include a handler_factory argument. This should only have been present on the start_server() method. * Fixed wait_closed() on SSHForwardListener to work properly when a close is in progress at the time of the call. Release 1.4.1 (23 Jan 2016) --------------------------- * Fixed a bug in SFTP introduced in 1.4.0 related to handling of responses to non-blocking file closes. * Updated code to avoid calling asyncio.async(), deprecated in Python 3.4.4. * Updated unit tests to avoid errors on systems with an older version of OpenSSL installed. Release 1.4.0 (17 Jan 2016) --------------------------- * Added ssh-agent client support, automatically using it when SSH_AUTH_SOCK is set and client private keys aren't explicitly provided. * Added new wait_closed() API on SSHConnection to allow applications to wait for a connection to be fully closed and updated examples to use it. * Added a new login_timeout argument when creating an SSH server. * Added a missing acknowledgement response when canceling port forwarding and fixed a few other issues related to cleaning up port forwarding listeners. * Added handlers to improve the catching and reporting of exceptions that are raised in asynchronous tasks. * Reworked channel state machine to perform clean up on a channel only after a close is both sent and received. * Fixed SSHChannel to run the connection_lost() handler on the SSHSession before unblocking callers of wait_closed().
* Fixed wait_closed() on SSHListener to wait for the acknowledgement from the SSH server before returning. * Fixed a race condition in port forwarding code. * Fixed a bug related to sending a close on a channel which got a failure when being opened. * Fixed a bug related to handling term_type being set without term_size. * Fixed some issues related to the automatic conversion of client keyboard-interactive auth to password auth. With this change, automatic conversion will only occur if the application doesn't override the kbdint_challenge_received() method and it will only attempt to authenticate once with the password provided. Release 1.3.2 (26 Nov 2015) --------------------------- * Added server-side support for handling password changes during password authentication, and fixed a few other auth-related bugs. * Added the ability to override the automatic support for keyboard-interactive authentication when password authentication is supported. * Fixed a race condition in unblocking streams. * Removed support for OpenSSH v00 certificates now that OpenSSH no longer supports them. * Added unit tests for auth module. Release 1.3.1 (6 Nov 2015) -------------------------- * Updated AsyncSSH to depend on version 1.1 or later of PyCA and added support for using its new Elliptic Curve Diffie Hellman (ECDH) implementation, replacing the previous AsyncSSH native Python version. * Added support for specifying a passphrase in the create_connection, create_server, connect, and listen functions to allow file names or byte strings containing encrypted client and server host keys to be specified in those calls. * Fixed handling of cancellation in a few AsyncSSH calls, so it is now possible to make calls to things like stream read or drain which time out. * Fixed a bug in keyboard-interactive fallback to password auth which was introduced when support was added for auth functions optionally being coroutines. * Move bcrypt check in encrypted key handling until it is needed so better errors can be returned if a passphrase is not specified or the key derivation function used in a key is unknown. * Added unit tests for the auth_keys module. * Updated unit tests to better handle bcrypt or libnacl not being installed. Release 1.3.0 (10 Oct 2015) --------------------------- * Updated AsyncSSH dependencies to make PyCA version 1.0.0 or later mandatory and remove the older PyCrypto support. This change also adds support for the PyCA implementation of ECDSA and removes support for RC2-based private key encryption that was only supported by PyCrypto. * Refactored ECDH and Curve25519 key exchange code so they can share an implementation, and prepared the code for adding a PyCA shim for this as soon as support for that is released. * Hardened the DSA and RSA implementations to do stricter checking of the key exchange response, and sped up the RSA implementation by taking advantage of optional RSA private key parameters when they are present. * Added support for asynchronous client and server authentication, allowing auth-related callbacks in SSHClient and SSHServer to optionally be defined as coroutines. * Added support for asynchronous SFTP server processing, allowing callbacks in SFTPServer to optionally be defined as coroutines. * Added support for a broader set of open mode flags in the SFTP server. Note that this change is not completely backward compatible with previous releases. 
If you have application code which expects a Python mode string as an argument to the SFTPServer open method, it will need to be changed to expect a pflags value instead. * Fixed handling of eof_received() when it returns false to close the half-open connection but still allow sending or receiving of exit status and exit signals. * Added unit tests for the asn1, cipher, compression, ec, kex, known_hosts, mac, and saslprep modules and expanded the set of pbe and public_key unit tests. * Fixed a set of issues uncovered by ASN.1 unit tests: * Removed extra 0xff byte when encoding integers of the form -128*256^n * Fixed decoding error for OIDs beginning with 2.n where n >= 40 * Fixed range check for second component of ObjectIdentifier * Added check for extraneous 0x80 bytes in ObjectIdentifier components * Added check for negative component values in ObjectIdentifier * Added error handling for ObjectIdentifier components being non-integer * Added handling for missing length byte after extended tag * Raised ASN1EncodeError instead of TypeError on unsupported types * Added validation on asn1_class argument, and equality and hash methods to BitString, RawDERObject, and TaggedDERObject. Also, reordered RawDERObject arguments to be consistent with TaggedDERObject and added str method to ObjectIdentifier. * Fixed a set of issues uncovered by additional pbe unit tests: * Encoding and decoding of PBES2-encrypted keys with a PRF other than SHA1 is now handled correctly. * Some exception messages were made more specific. * Additional checks were put in for empty salt or zero iteration count in encryption parameters. * Fixed a set of issues uncovered by additional public key unit tests: * Properly handle PKCS#8 keys with invalid ASN.1 data * Properly handle PKCS#8 DSA & RSA keys with non-sequence for arg_params * Properly handle attempts to import empty string as a public key * Properly handle encrypted PEM keys with missing DEK-Info header * Report check byte mismatches for encrypted OpenSSH keys as bad passphrase * Return KeyImportError instead of KeyEncryptionError when passphrase is needed but not provided * Added information about branches to CONTRIBUTING guide. * Performed a bunch of code cleanup suggested by pylint. Release 1.2.1 (26 Aug 2015) --------------------------- * Fixed a problem with passing in client_keys=None to disable public key authentication in the SSH client. * Updated Unicode handling to allow multi-byte Unicode characters to be split across successive SSH data messages. * Added a note to the documentation for AsyncSSH create_connection() explaining how to perform the equivalent of a connect with a timeout. Release 1.2.0 (6 Jun 2015) -------------------------- * Fixed a problem with the SSHConnection context manager on Python versions older than 3.4.2. * Updated the documentation for get_extra_info() in the SSHConnection, SSHChannel, SSHReader, and SSHWriter classes to contain pointers to get_extra_info() in their parent transports to make it easier to see all of the attributes which can be queried. * Clarified the legal return values for the session_requested(), connection_requested(), and server_requested() methods in SSHServer. * Eliminated calls to the deprecated importlib.find_loader() method. * Made improvements to README suggested by Nicholas Chammas. * Fixed a number of issues identified by pylint. Release 1.1.1 (25 May 2015) --------------------------- * Added new start_sftp_server method on SSHChannel to allow applications using the non-streams API to start an SFTP server.
* Enhanced the default format_longname() method in SFTPServer to properly handle the case where not all of the file attributes are returned by stat(). * Fixed a bug related to the new allow_pty parameter in create_server. * Fixed a bug in the hashed known_hosts support introduced in some recent refactoring of the host pattern matching code. Release 1.1.0 (22 May 2015) --------------------------- * SFTP is now supported! * Both client and server support is available. * SFTP version 3 is supported, with OpenSSH extensions. * Recursive transfers and glob matching are supported in the client. * File I/O APIs allow files to be accessed without downloading them. * New simplified connect and listen APIs have been added. * SSHConnection can now be used as a context manager. * New arguments to create_server now allow the specification of a session_factory and encoding or sftp_factory as well as controls over whether a pty is allowed and the window and max packet size, avoiding the need to create custom SSHServer subclasses or custom SSHServerChannel instances. * New examples have been added for SFTP and to show the use of the new connect and listen APIs. * Copyrights in changed files have all been updated to 2015. Release 1.0.1 (13 Apr 2015) --------------------------- * Fixed a bug in OpenSSH private key encryption introduced in some recent cipher refactoring. * Added bcrypt and libnacl as optional dependencies in setup.py. * Changed test_keys test to work properly when bcrypt or libnacl aren't installed. Release 1.0.0 (11 Apr 2015) --------------------------- * This release finishes adding a number of major features, finally making it worthy of being called a "1.0" release. * Host and user certificates are now supported! * Enforcement is done on principals in certificates. * Enforcement is done on force-command and source-address critical options. * Enforcement is done on permit-pty and permit-port-forwarding extensions. * OpenSSH-style known hosts files are now supported! * Positive and negative wildcard and CIDR-style patterns are supported. * HMAC-SHA1 hashed host entries are supported. * The @cert-authority and @revoked markers are supported. * OpenSSH-style authorized keys files are now supported! * Both client keys and certificate authorities are supported. * Enforcement is done on from and principals options during key matching. * Enforcement is done on no-pty, no-port-forwarding, and permitopen. * The command and environment options are supported. * Applications can query for their own non-standard options. * Support has been added for OpenSSH format private keys. * DSA, RSA, and ECDSA keys in this format are now supported. * Ed25519 keys are supported when libnacl and libsodium are installed. * OpenSSH private key encryption is supported when bcrypt is installed. * Curve25519 Diffie-Hellman key exchange is now available via either the curve25519-donna or libnacl and libsodium packages. * ECDSA key support has been enhanced. * Support is now available for PKCS#8 ECDSA v2 keys. * Support is now available for both NamedCurve and explicit ECParameter versions of keys, as long as the parameters match one of the supported curves (nistp256, nistp384, or nistp521). * Support is now available for the OpenSSH chacha20-poly1305 cipher when libnacl and libsodium are installed. * Cipher names specified in private key encryption have been changed to be consistent with OpenSSH cipher naming, and all SSH ciphers can now be used for encryption of keys in OpenSSH private key format. 
* A couple of race conditions in SSHChannel have been fixed and channel cleanup is now delayed to allow outstanding message handling to finish. * Channel exceptions are now properly delivered in the streams API. * A bug in SSHStream read() where it could sometimes return more data than requested has been fixed. Also, read() has been changed to properly block and return all data until EOF or a signal is received when it is called with no length. * A bug in the default implementation of keyboard-interactive authentication has been fixed, and the matching of a password prompt has been loosened to allow it to be used for password authentication on more devices. * Missing code to resume reading after a stream is paused has been added. * Improvements have been made in the handling of canceled requests. * The test code has been updated to test Ed25519 and OpenSSH format private keys. * Examples have been updated to reflect some of the new capabilities. Release 0.9.2 (26 Jan 2015) --------------------------- * Fixed a bug in PyCrypto CipherFactory introduced during PyCA refactoring. Release 0.9.1 (3 Dec 2014) -------------------------- * Added some missing items in setup.py and MANIFEST.in. * Fixed the install to work even when cryptographic dependencies aren't yet installed. * Fixed an issue where get_extra_info calls could fail if called when a connection or session was shutting down. Release 0.9.0 (14 Nov 2014) --------------------------- * Added support to use PyCA (0.6.1 or later) for cryptography. AsyncSSH will automatically detect and use either PyCA, PyCrypto, or both depending on which is installed and which algorithms are requested. * Added support for AES-GCM ciphers when PyCA is installed. Release 0.8.4 (12 Sep 2014) --------------------------- * Fixed an error in the encode/decode functions for PKCS#1 DSA public keys. * Fixed a bug in the unit test code for import/export of RFC4716 public keys. Release 0.8.3 (16 Aug 2014) --------------------------- * Added a missing import in the curve25519 implementation. Release 0.8.2 (16 Aug 2014) --------------------------- * Provided a better long description for PyPI. * Added link to PyPI in documentation sidebar. Release 0.8.1 (15 Aug 2014) --------------------------- * Added a note in the :meth:`validate_public_key() ` documentation clarifying that AsyncSSH will verify that the client possesses the corresponding private key before authentication is allowed to succeed. * Switched from setuptools to distutils and added an initial set of unit tests. * Prepared the package to be uploaded to PyPI. Release 0.8.0 (15 Jul 2014) --------------------------- * Added support for Curve25519 Diffie Hellman key exchange on systems with the curve25519-donna Python package installed. * Updated the examples to more clearly show what values are returned even when not all of the return values are used. Release 0.7.0 (7 Jun 2014) -------------------------- * This release adds support for the "high-level" ``asyncio`` streams API, in the form of the :class:`SSHReader` and :class:`SSHWriter` classes and wrapper methods such as :meth:`open_session() `, :meth:`open_connection() `, and :meth:`start_server() `. It also allows the callback methods on :class:`SSHServer` to return either SSH session objects or handler functions that take :class:`SSHReader` and :class:`SSHWriter` objects as arguments. See :meth:`session_requested() `, :meth:`connection_requested() `, and :meth:`server_requested() ` for more information. 
* Added new exceptions :exc:`BreakReceived`, :exc:`SignalReceived`, and :exc:`TerminalSizeChanged` to report when these messages are received while trying to read from an :class:`SSHServerChannel` using the new streams API. * Changed :meth:`create_server() ` to accept either a callable or a coroutine for its ``session_factory`` argument, to allow asynchronous operations to be used when deciding whether to accept a forwarded TCP connection. * Renamed ``accept_connection()`` to :meth:`create_connection() ` in the :class:`SSHServerConnection` class for consistency with :class:`SSHClientConnection`, and added a corresponding :meth:`open_connection() ` method as part of the streams API. * Added :meth:`get_exit_status() ` and :meth:`get_exit_signal() ` methods to the :class:`SSHClientChannel` class. * Added :meth:`get_command() ` and :meth:`get_subsystem() ` methods to the :class:`SSHServerChannel` class. * Fixed the name of the :meth:`write_stderr() ` method and added the missing :meth:`writelines_stderr() ` method to the :class:`SSHServerChannel` class for outputting data to the stderr channel. * Added support for a return value in the :meth:`eof_received() ` of :class:`SSHClientSession`, :class:`SSHServerSession`, and :class:`SSHTCPSession` to support half-open channels. By default, the channel is automatically closed after :meth:`eof_received() ` returns, but returning ``True`` will now keep the channel open, allowing output to still be sent on the half-open channel. This is done automatically when the new streams API is used. * Added values ``'local_peername'`` and ``'remote_peername'`` to the set of information available from the :meth:`get_extra_info() ` method in the :class:`SSHTCPChannel` class. * Updated functions returning :exc:`IOError` or :exc:`socket.error` to return the new :exc:`OSError` exception introduced in Python 3.3. * Cleaned up some errors in the documentation. * The :ref:`API`, :ref:`ClientExamples`, and :ref:`ServerExamples` have all been updated to reflect these changes, and new examples showing the streams API have been added. Release 0.6.0 (11 May 2014) --------------------------- * This release is a major revamp of the code to migrate from the ``asyncore`` framework to the new ``asyncio`` framework in Python 3.4. All the APIs have been adapted to fit the new ``asyncio`` paradigm, using coroutines wherever possible to avoid the need for callbacks when performing asynchronous operations. So far, this release only supports the "low-level" ``asyncio`` API. * The :ref:`API`, :ref:`ClientExamples`, and :ref:`ServerExamples` have all been updated to reflect these changes. Release 0.5.0 (11 Oct 2013) --------------------------- * Added the following new classes to support fully asynchronous connection forwarding, replacing the methods previously added in release 0.2.0: * :class:`SSHClientListener` * :class:`SSHServerListener` * :class:`SSHClientLocalPortForwarder` * :class:`SSHClientRemotePortForwarder` * :class:`SSHServerPortForwarder` These new classes allow for DNS lookups and other operations to be performed fully asynchronously when new listeners are set up. As with the asynchronous connect changes below, methods are now available to report when the listener is opened or when an error occurs during the open rather than requiring the listener to be fully set up in a single call. * Updated examples in :ref:`ClientExamples` and :ref:`ServerExamples` to reflect the above changes. 
Release 0.4.0 (28 Sep 2013) --------------------------- * Added support in :class:`SSHTCPConnection` for the following methods to allow asynchronous operations to be used when accepting inbound connection requests: * :meth:`handle_open_request() ` * :meth:`report_open() ` * :meth:`report_open_error() ` These new methods are used to implement asynchronous connect support for local and remote port forwarding, and to support trying multiple destination addresses when connection failures occur. * Cleaned up a few minor documentation errors. Release 0.3.0 (26 Sep 2013) --------------------------- * Added support in :class:`SSHClient` and :class:`SSHServer` for setting the key exchange, encryption, MAC, and compression algorithms allowed in the SSH handshake. * Refactored the algorithm selection code to pull a common matching function back into ``_SSHConnection`` and simplify other modules. * Extended the listener class to open multiple listening sockets when necessary, fixing a bug where sockets opened to listen on ``localhost`` were not properly accepting both IPv4 and IPv6 connections. Now, any listen request which resolves to multiple addresses will open listening sockets for each address. * Fixed a bug related to tracking of listeners opened on dynamic ports. Release 0.2.0 (21 Sep 2013) --------------------------- * Added support in :class:`SSHClient` for the following methods related to performing standard SSH port forwarding: * :meth:`forward_local_port() ` * :meth:`cancel_local_port_forwarding() ` * :meth:`forward_remote_port() ` * :meth:`cancel_remote_port_forwarding() ` * :meth:`handle_remote_port_forwarding() ` * :meth:`handle_remote_port_forwarding_error() ` * Added support in :class:`SSHServer` for new return values in :meth:`handle_direct_connection() ` and :meth:`handle_listen() ` to activate standard SSH server-side port forwarding. * Added a client_addr argument and member variable to :class:`SSHServer` to hold the client's address information. * Added and updated examples related to port forwarding and using :class:`SSHTCPConnection` to open direct and forwarded TCP connections in :ref:`ClientExamples` and :ref:`ServerExamples`. * Cleaned up some of the other documentation. * Removed a debug print statement accidentally left in related to SSH rekeying. Release 0.1.0 (14 Sep 2013) --------------------------- * Initial release asyncssh-1.11.1/docs/conf.py000066400000000000000000000173511320320510200156500ustar00rootroot00000000000000#!/usr/bin/env python3 # -*- coding: utf-8 -*- # # AsyncSSH documentation build configuration file, created by # sphinx-quickstart on Sun Sep 1 17:36:31 2013. # # This file is execfile()d with the current directory set to its containing dir. # # Note that not all possible configuration values are present in this # autogenerated file. # # All configuration values have a default; values that are commented out # serve to show the default. import sys, os # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. sys.path.insert(0, os.path.abspath('..')) from asyncssh import __author__, __version__ # -- General configuration ----------------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. #needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. 
They can be extensions # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. extensions = ['sphinx.ext.autodoc', 'sphinx.ext.doctest', 'sphinx.ext.todo', 'sphinx.ext.coverage', 'sphinx.ext.viewcode', 'sphinx.ext.intersphinx'] # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] # The suffix of source filenames. source_suffix = '.rst' # The encoding of source files. #source_encoding = 'utf-8-sig' # The master toctree document. master_doc = 'index' # General information about the project. project = 'AsyncSSH' copyright = '2013-2017, ' + __author__ # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The full version, including alpha/beta/rc tags. release = __version__ # The short X.Y.Z version. version = '.'.join(release.split('.')[:3]) # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. #language = None # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: #today = '' # Else, today_fmt is used as the format for a strftime call. #today_fmt = '%B %d, %Y' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. exclude_patterns = ['_build'] # The reST default role (used for this markup: `text`) to use for all documents. #default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. #add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). #add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. #show_authors = False # The name of the Pygments (syntax highlighting) style to use. pygments_style = 'sphinx' highlight_language = 'python3' # A list of ignored prefixes for module index sorting. #modindex_common_prefix = [] # -- Options for HTML output --------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. html_theme = 'rftheme' # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. html_theme_options = { "sidebarwidth": 305, "stickysidebar": "true" } # Add any paths that contain custom themes here, relative to this directory. html_theme_path = ['.'] # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". #html_title = None # A shorter title for the navigation bar. Default is the same as html_title. #html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. #html_logo = None # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. #html_favicon = None # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ['_static'] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. 
#html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. #html_use_smartypants = True # Custom sidebar templates, maps document names to template names. html_sidebars = {'**': ['sidebartop.html', 'localtoc.html', 'sidebarbottom.html']} # Additional templates that should be rendered to pages, maps page names to # template names. #html_additional_pages = {} # If false, no module index is generated. html_domain_indices = False # If false, no index is generated. #html_use_index = True # If true, the index is split into individual pages for each letter. #html_split_index = False # If true, links to the reST sources are added to the pages. html_show_sourcelink = False # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. #html_show_sphinx = True # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. #html_show_copyright = True # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. #html_use_opensearch = '' # This is the file name suffix for HTML files (e.g. ".xhtml"). #html_file_suffix = None # Output file base name for HTML help builder. htmlhelp_basename = 'AsyncSSHdoc' # -- Options for LaTeX output -------------------------------------------------- latex_elements = { # The paper size ('letterpaper' or 'a4paper'). #'papersize': 'letterpaper', # The font size ('10pt', '11pt' or '12pt'). #'pointsize': '10pt', # Additional stuff for the LaTeX preamble. #'preamble': '', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, author, documentclass [howto/manual]). latex_documents = [ ] # The name of an image file (relative to this directory) to place at the top of # the title page. #latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. #latex_use_parts = False # If true, show page references after internal links. #latex_show_pagerefs = False # If true, show URL addresses after external links. #latex_show_urls = False # Documents to append as an appendix to all manuals. #latex_appendices = [] # If false, no module index is generated. #latex_domain_indices = True # -- Options for manual page output -------------------------------------------- # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ ] # If true, show URL addresses after external links. #man_show_urls = False # -- Options for Texinfo output ------------------------------------------------ # Grouping the document tree into Texinfo files. List of tuples # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ ] # Documents to append as an appendix to all manuals. #texinfo_appendices = [] # If false, no module index is generated. #texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. #texinfo_show_urls = 'footnote' intersphinx_mapping = {'python': ('https://docs.python.org/3', None)} asyncssh-1.11.1/docs/contributing.rst000066400000000000000000000000411320320510200175760ustar00rootroot00000000000000.. include:: ../CONTRIBUTING.rst asyncssh-1.11.1/docs/index.rst000066400000000000000000000504631320320510200162130ustar00rootroot00000000000000.. toctree:: :hidden: changes contributing api .. 
currentmodule:: asyncssh .. include:: ../README.rst .. _ClientExamples: Client Examples =============== Simple client ------------- The following code shows an example of a simple SSH client which logs into localhost and lists files in a directory named 'abc' under the user's home directory. The username provided is the logged in user, and the user's default SSH client keys or certificates are presented during authentication. The server's host key is checked against the user's SSH known_hosts file and the connection will fail if there's no entry for localhost there or if the key doesn't match. .. include:: ../examples/simple_client.py :literal: :start-line: 14 This example only uses the output on stdout, but output on stderr is also collected as another attribute in the returned :class:`SSHCompletedProcess` object. To check against a different set of server host keys, they can be provided in the known_hosts argument when the connection is opened: .. code:: async with asyncssh.connect('localhost', known_hosts='my_known_hosts') as conn: Server host key checking can be disabled by setting the known_hosts argument to ``None``, but that's not recommended as it makes the connection vulnerable to a man-in-the-middle attack. To log in as a different remote user, the username argument can be provided: .. code:: async with asyncssh.connect('localhost', username='user123') as conn: To use a different set of client keys for authentication, they can be provided in the client_keys argument: .. code:: async with asyncssh.connect('localhost', client_keys=['my_ssh_key']) as conn: Password authentication can be used by providing a password argument: .. code:: async with asyncssh.connect('localhost', password='secretpw') as conn: Any of the arguments above can be combined together as needed. If client keys and a password are both provided, either may be used depending on what forms of authentication the server supports and whether the authentication with them is successful. Callback example ---------------- AsyncSSH also provides APIs that use callbacks rather than "await" and "async with". Here's the example above written using custom :class:`SSHClient` and :class:`SSHClientSession` subclasses: .. include:: ../examples/callback_client.py :literal: :start-line: 14 In cases where you don't need to customize callbacks on the SSHClient class, this code can be simplified somewhat to: .. include:: ../examples/callback_client2.py :literal: :start-line: 14 If you need to distinguish output going to stdout vs. stderr, that's easy to do with the following change: .. include:: ../examples/callback_client3.py :literal: :start-line: 14 Interactive input ----------------- The following example demonstrates sending interactive input to a remote process. It executes the calculator program ``bc`` and performs some basic math calculations. Note that it uses the :meth:`create_process ` method rather than the :meth:`run ` method. This starts the process but doesn't wait for it to exit, allowing interaction with it. .. include:: ../examples/math_client.py :literal: :start-line: 14 When run, this program should produce the following output: .. code:: 2+2 = 4 1*2*3*4 = 24 2^32 = 4294967296 I/O redirection --------------- The following example shows how to pass a fixed input string to a remote process and redirect the resulting output to the local file '/tmp/stdout'. Input lines containing 1, 2, and 3 are passed into the 'tail -r' command and the output written to '/tmp/stdout' should contain the reversed lines 3, 2, and 1: .. 
include:: ../examples/redirect_input.py :literal: :start-line: 14 The ``stdin``, ``stdout``, and ``stderr`` arguments support redirecting to a variety of locations including local files, pipes, and sockets as well as :class:`SSHReader` or :class:`SSHWriter` objects associated with other remote SSH processes. Here's an example of piping stdout from a local process to a remote process: .. include:: ../examples/redirect_local_pipe.py :literal: :start-line: 14 Here's an example of piping one remote process to another: .. include:: ../examples/redirect_remote_pipe.py :literal: :start-line: 14 In this example both remote processes are running on the same SSH connection, but this redirection can just as easily be used between SSH sessions associated with connections going to different servers. Checking exit status -------------------- The following example shows how to test the exit status of a remote process: .. include:: ../examples/check_exit_status.py :literal: :start-line: 14 If an exit signal is received, the exit status will be set to -1 and exit signal information is provided in the ``exit_signal`` attribute of the returned :class:`SSHCompletedProcess`. If the ``check`` argument in :meth:`run ` is set to ``True``, any abnormal exit will raise a :exc:`ProcessError` exception instead of returning an :class:`SSHCompletedProcess`. Running multiple clients ------------------------ The following example shows how to run multiple clients in parallel and process the results when all of them have completed: .. include:: ../examples/gather_results.py :literal: :start-line: 14 Results could be processed as they became available by setting up a loop which repeatedly called :func:`asyncio.wait` instead of calling :func:`asyncio.gather`. Setting environment variables ----------------------------- The following example demonstrates setting environment variables for the remote session and displaying them by executing the 'env' command. .. include:: ../examples/set_environment.py :literal: :start-line: 14 Any number of environment variables can be passed in the dictionary given to :meth:`create_session() `. Note that SSH servers may restrict which environment variables (if any) are accepted, so this feature may require setting options on the SSH server before it will work. Setting terminal information ---------------------------- The following example demonstrates setting the terminal type and size passed to the remote session. .. include:: ../examples/set_terminal.py :literal: :start-line: 14 Note that this will cause AsyncSSH to request a pseudo-tty from the server. When a pseudo-tty is used, the server will no longer send output going to stderr with a different data type. Instead, it will be mixed with output going to stdout (unless it is redirected elsewhere by the remote command). Port forwarding --------------- The following example demonstrates the client setting up a local TCP listener on port 8080 and requesting that connections which arrive on that port be forwarded across SSH to the server and on to port 80 on ``www.google.com``: .. include:: ../examples/local_forwarding_client.py :literal: :start-line: 14 To listen on a dynamically assigned port, the client can pass in ``0`` as the listening port. If the listener is successfully opened, the selected port will be available via the :meth:`get_port() ` method on the returned listener object: .. include:: ../examples/local_forwarding_client2.py :literal: :start-line: 14 The client can also request remote port forwarding from the server.
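At its simplest, this is a minimal sketch of such a request (assuming ``conn`` is an already-open :class:`SSHClientConnection`; error handling is omitted):

.. code::

      # Ask the server to listen on port 8080 and forward connections it
      # receives back over SSH and on to port 80 on localhost
      listener = await conn.forward_remote_port('', 8080, 'localhost', 80)
      await listener.wait_closed()
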
The following example shows the client requesting that the server listen on port 8080 and that connections arriving there be forwarded across SSH and on to port 80 on ``localhost``: .. include:: ../examples/remote_forwarding_client.py :literal: :start-line: 14 To limit which connections are accepted or dynamically select where to forward traffic to, the client can implement their own session factory and call :meth:`forward_connection() ` on the connections they wish to forward and raise an error on those they wish to reject: .. include:: ../examples/remote_forwarding_client2.py :literal: :start-line: 14 Just as with local listeners, the client can request remote port forwarding from a dynamic port by passing in ``0`` as the listening port and then call :meth:`get_port() ` on the returned listener to determine which port was selected. Direct TCP connections ---------------------- The client can also ask the server to open a TCP connection and directly send and receive data on it by using the :meth:`create_connection() ` method on the :class:`SSHClientConnection` object. In this example, a connection is attempted to port 80 on ``www.google.com`` and an HTTP HEAD request is sent for the document root. Note that unlike sessions created with :meth:`create_session() `, the I/O on these connections defaults to sending and receiving bytes rather than strings, allowing arbitrary binary data to be exchanged. However, this can be changed by setting the encoding to use when the connection is created. .. include:: ../examples/direct_client.py :literal: :start-line: 14 To use the streams API to open a direct connection, you can use :meth:`open_connection ` instead of :meth:`create_connection `: .. include:: ../examples/stream_direct_client.py :literal: :start-line: 14 Forwarded TCP connections ------------------------- The client can also directly process data from incoming TCP connections received on the server. The following example demonstrates the client requesting that the server listen on port 8888 and forward any received connections back to it over SSH. It then has a simple handler which echoes any data it receives back to the sender. As in the direct TCP connection example above, the default would be to send and receive bytes on this connection rather than strings, but here we set the encoding explicitly so all data is sent and received as strings: .. include:: ../examples/listening_client.py :literal: :start-line: 14 To use the streams API to open a listening connection, you can use :meth:`start_server ` instead of :meth:`create_server `: .. include:: ../examples/stream_listening_client.py :literal: :start-line: 14 SFTP client ----------- AsyncSSH also provides SFTP support. The following code shows an example of starting an SFTP client and requesting the download of a file: .. include:: ../examples/sftp_client.py :literal: :start-line: 14 To recursively download a directory, preserving access and modification times and permissions on the files, the preserve and recurse arguments can be included: .. code:: await sftp.get('example_dir', preserve=True, recurse=True) Wild card pattern matching is supported by the :meth:`mget `, :meth:`mput `, and :meth:`mcopy ` methods. The following downloads all files with extension "txt": .. code:: await sftp.mget('*.txt') See the :class:`SFTPClient` documentation for the full list of available actions. SCP client ---------- AsyncSSH also supports SCP. The following code shows an example of downloading a file via SCP: .. 
include:: ../examples/scp_client.py :literal: :start-line: 14 To upload a file to a remote system, host information can be specified for the destination instead of the source: .. code:: await asyncssh.scp('example.txt', 'localhost:') If the destination path includes a file name, that name will be used instead of the original file name when performing the copy. For instance: .. code:: await asyncssh.scp('example.txt', 'localhost:example2.txt') If the destination path refers to a directory, the origin file name will be preserved, but it will be copied into the requested directory. Wild card patterns are also supported on source paths. For instance, the following copies all files with extension "txt": .. code:: await asyncssh.scp('*.txt', 'localhost:') Similar to SFTP, SCP also supports options for recursively copying a directory and preserving modification times and permissions on files using the preserve and recurse arguments: .. code:: await asyncssh.scp('example_dir', 'localhost:', preserve=True, recurse=True) In addition to the ``'host:path'`` syntax for source and destination paths, a tuple of the form ``(host, path)`` is also supported. A non-default port can be specified by replacing ``host`` with ``(host, port)``, resulting in something like: .. code:: await asyncssh.scp((('localhost', 8022), 'example.txt'), '.') An already open :class:`SSHClientConnection` can also be passed as the host: .. code:: async with asyncssh.connect('localhost') as conn: await asyncssh.scp((conn, 'example.txt'), '.') Multiple file patterns can be copied to the same destination by making the source path argument a list. Source paths in this list can be a mixture of local and remote file references and the destination path can be local or remote, but one or both of source and destination must be remote. Local to local copies are not supported. See the :func:`scp` function documentation for the complete list of available options. .. _ServerExamples: Server Examples =============== Simple server ------------- The following code shows an example of a simple SSH server which listens for connections on port 8022, does password authentication, and prints a message when users authenticate successfully and start a shell. .. include:: ../examples/simple_server.py :literal: :start-line: 14 To authenticate with SSH client keys or certificates, the server would look something like the following. Client and certificate authority keys for each user need to be placed in a file matching the username in a directory called ``authorized_keys``. .. include:: ../examples/simple_keyed_server.py :literal: :start-line: 22 It is also possible to use a single authorized_keys file for all users. This is common when using certificates, as AsyncSSH can automatically enforce that the certificates presented have a principal in them which matches the username. In this case, a custom :class:`SSHServer` subclass is no longer required, and so the :func:`listen` function can be used in place of :func:`create_server`. .. include:: ../examples/simple_cert_server.py :literal: :start-line: 21 Simple server with input ------------------------ The following example demonstrates reading input in a server session. It adds a column of numbers, displaying the total when it receives EOF. .. include:: ../examples/math_server.py :literal: :start-line: 21 Callback example ---------------- Here's an example of the server above written using callbacks in custom :class:`SSHServer` and :class:`SSHServerSession` subclasses. .. 
include:: ../examples/callback_math_server.py :literal: :start-line: 21 I/O redirection --------------- The following shows an example of I/O redirection on the server side, executing a process on the server with input and output redirected back to the SSH client: .. include:: ../examples/redirect_server.py :literal: :start-line: 14 Serving multiple clients ------------------------ The following is a slightly more complicated example showing how a server can manage multiple simultaneous clients. It implements a basic chat service, where clients can send messages to one another. .. include:: ../examples/chat_server.py :literal: :start-line: 21 Line editing ------------ When SSH clients request a pseudo-terminal, they generally default to sending input a character at a time and expect the remote system to provide character echo and line editing. To better support interactive applications like the one above, AsyncSSH defaults to providing basic line editing for server sessions which request a pseudo-terminal. When this line editor is enabled, it defaults to delivering input to the application a line at a time. Applications can switch between line and character at a time input using the :meth:`set_line_mode() ` method. Also, when in line mode, applications can enable or disable echoing of input using the :meth:`set_echo() ` method. The following code provides an example of this. .. include:: ../examples/editor.py :literal: :start-line: 21 Getting environment variables ----------------------------- The following example demonstrates reading environment variables set by the client. It will show all of the variables set by the client, or return an error if none are set. Note that SSH clients may restrict which environment variables (if any) are sent by default, so you may need to set options in the client to get it to do so. .. include:: ../examples/show_environment.py :literal: :start-line: 21 Getting terminal information ---------------------------- The following example demonstrates reading the client's terminal type and window size, and handling window size changes during a session. .. include:: ../examples/show_terminal.py :literal: :start-line: 21 Port forwarding --------------- The following example demonstrates a server accepting port forwarding requests from clients, but only when they are destined to port 80. When such a connection is received, a connection is attempted to the requested host and port and data is bidirectionally forwarded over SSH from the client to this destination. Requests by the client to connect to any other port are rejected. .. include:: ../examples/local_forwarding_server.py :literal: :start-line: 21 The server can also support forwarding inbound TCP connections back to the client. The following example demonstrates a server which will accept requests like this from clients, but only to listen on port 8080. When such a connection is received, the client is notified and data is bidirectionally forwarded from the incoming connection over SSH to the client. .. include:: ../examples/remote_forwarding_server.py :literal: :start-line: 21 Direct TCP connections ---------------------- The server can also accept direct TCP connection requests from the client and process the data on them itself. The following example demonstrates a server which accepts requests to port 7 (the "echo" port) for any host and echoes the data itself rather than forwarding the connection: ..
include:: ../examples/direct_server.py :literal: :start-line: 21 Here's an example of this server written using the streams API. In this case, :meth:`connection_requested() ` returns a handler coroutine instead of a session object. When a new direct TCP connection is opened, the handler coroutine is called with AsyncSSH stream objects which can be used to perform I/O on the tunneled connection. .. include:: ../examples/stream_direct_server.py :literal: :start-line: 21 SFTP server ----------- The following example shows how to start an SFTP server with default behavior: .. include:: ../examples/simple_sftp_server.py :literal: :start-line: 21 A subclass of :class:`SFTPServer` can be provided as the value of the SFTP factory to override specific behavior. For example, the following code remaps path names so that each user gets access to only their own individual directory under ``/tmp/sftp``: .. include:: ../examples/chroot_sftp_server.py :literal: :start-line: 21 More complex path remapping can be performed by implementing the :meth:`map_path ` and :meth:`reverse_map_path ` methods. Individual SFTP actions can also be overridden as needed. See the :class:`SFTPServer` documentation for the full list of methods to override. SCP server ---------- The above server examples can be modified to also support SCP by simply adding ``allow_scp=True`` alongside the specification of the ``sftp_factory`` in the :func:`listen` call. This will use the same :class:`SFTPServer` instance when performing file I/O for both SFTP and SCP requests. For instance: .. include:: ../examples/simple_scp_server.py :literal: :start-line: 21 asyncssh-1.11.1/docs/rftheme/000077500000000000000000000000001320320510200157745ustar00rootroot00000000000000asyncssh-1.11.1/docs/rftheme/layout.html000066400000000000000000000002631320320510200202000ustar00rootroot00000000000000{% extends "basic/layout.html" %} {# Omit the top navigation bar. #} {% block relbar1 %} {% endblock %} {# Omit the bottom navigation bar. #} {% block relbar2 %} {% endblock %} asyncssh-1.11.1/docs/rftheme/static/000077500000000000000000000000001320320510200172635ustar00rootroot00000000000000asyncssh-1.11.1/docs/rftheme/static/rftheme.css_t000066400000000000000000000006141320320510200217530ustar00rootroot00000000000000@import url("classic.css"); .tight-list * { line-height: 110% !important; margin: 0 0 3px !important; } div.sphinxsidebar { top: 0; } div.sphinxsidebarwrapper { padding-top: 8px; } div.body p, div.body dd, div.body li { text-align: left; } tt, .note tt { font-size: 1.15em; background: none; } div.body p.rubric { font-size: 1.3em; margin: 15px 0 5px; } asyncssh-1.11.1/docs/rftheme/theme.conf000066400000000000000000000001131320320510200177400ustar00rootroot00000000000000[theme] inherit = classic stylesheet = rftheme.css pygments_style = sphinx asyncssh-1.11.1/docs/rtd-req.txt000066400000000000000000000000261320320510200164570ustar00rootroot00000000000000cryptography >= 0.6.1 asyncssh-1.11.1/examples/000077500000000000000000000000001320320510200152305ustar00rootroot00000000000000asyncssh-1.11.1/examples/callback_client.py000077500000000000000000000024721320320510200207040ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2013-2016 by Ron Frederick . # All rights reserved. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys class MySSHClientSession(asyncssh.SSHClientSession): def data_received(self, data, datatype): print(data, end='') def connection_lost(self, exc): if exc: print('SSH session error: ' + str(exc), file=sys.stderr) class MySSHClient(asyncssh.SSHClient): def connection_made(self, conn): print('Connection made to %s.' % conn.get_extra_info('peername')[0]) def auth_completed(self): print('Authentication successful.') async def run_client(): conn, client = await asyncssh.create_connection(MySSHClient, 'localhost') async with conn: chan, session = await conn.create_session(MySSHClientSession, 'ls abc') await chan.wait_closed() try: asyncio.get_event_loop().run_until_complete(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-1.11.1/examples/callback_client2.py000077500000000000000000000020471320320510200207640ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2013-2016 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys class MySSHClientSession(asyncssh.SSHClientSession): def data_received(self, data, datatype): print(data, end='') def connection_lost(self, exc): if exc: print('SSH session error: ' + str(exc), file=sys.stderr) async def run_client(): async with asyncssh.connect('localhost') as conn: chan, session = await conn.create_session(MySSHClientSession, 'ls abc') await chan.wait_closed() try: asyncio.get_event_loop().run_until_complete(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-1.11.1/examples/callback_client3.py000077500000000000000000000022401320320510200207600ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2013-2016 by Ron Frederick . # All rights reserved. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys class MySSHClientSession(asyncssh.SSHClientSession): def data_received(self, data, datatype): if datatype == asyncssh.EXTENDED_DATA_STDERR: print(data, end='', file=sys.stderr) else: print(data, end='') def connection_lost(self, exc): if exc: print('SSH session error: ' + str(exc), file=sys.stderr) async def run_client(): async with asyncssh.connect('localhost') as conn: chan, session = await conn.create_session(MySSHClientSession, 'ls abc') await chan.wait_closed() try: asyncio.get_event_loop().run_until_complete(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-1.11.1/examples/callback_math_server.py000077500000000000000000000042771320320510200217520ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2013-2016 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. # # The file ``ssh_user_ca`` must exist with a cert-authority entry of # the certificate authority which can sign valid client certificates. import asyncio, asyncssh, sys class MySSHServerSession(asyncssh.SSHServerSession): def __init__(self): self._input = '' self._total = 0 def connection_made(self, chan): self._chan = chan def shell_requested(self): return True def session_started(self): self._chan.write('Enter numbers one per line, or EOF when done:\n') def data_received(self, data, datatype): self._input += data lines = self._input.split('\n') for line in lines[:-1]: try: if line: self._total += int(line) except ValueError: self._chan.write_stderr('Invalid number: %s\n' % line) self._input = lines[-1] def eof_received(self): self._chan.write('Total = %s\n' % self._total) self._chan.exit(0) def break_received(self, msec): self.eof_received() class MySSHServer(asyncssh.SSHServer): def session_requested(self): return MySSHServerSession() async def start_server(): await asyncssh.create_server(MySSHServer, '', 8022, server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca') loop = asyncio.get_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('Error starting server: ' + str(exc)) loop.run_forever() asyncssh-1.11.1/examples/chat_server.py000077500000000000000000000043341320320510200201160ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2016-2017 by Ron Frederick . # All rights reserved. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. # # The file ``ssh_user_ca`` must exist with a cert-authority entry of # the certificate authority which can sign valid client certificates. import asyncio, asyncssh, sys class ChatClient: _clients = [] def __init__(self, process): self._process = process @classmethod async def handle_client(cls, process): await cls(process).run() def write(self, msg): self._process.stdout.write(msg) def broadcast(self, msg): for client in self._clients: if client != self: client.write(msg) async def run(self): self.write('Welcome to chat!\n\n') self.write('Enter your name: ') name = (await self._process.stdin.readline()).rstrip('\n') self.write('\n%d other users are connected.\n\n' % len(self._clients)) self._clients.append(self) self.broadcast('*** %s has entered chat ***\n' % name) try: async for line in self._process.stdin: self.broadcast('%s: %s' % (name, line)) except asyncssh.BreakReceived: pass self.broadcast('*** %s has left chat ***\n' % name) self._clients.remove(self) async def start_server(): await asyncssh.listen('', 8022, server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca', process_factory=ChatClient.handle_client) loop = asyncio.get_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('Error starting server: ' + str(exc)) loop.run_forever() asyncssh-1.11.1/examples/check_exit_status.py000077500000000000000000000017501320320510200213210ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2013-2016 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys async def run_client(): async with asyncssh.connect('localhost') as conn: result = await conn.run('ls abc') if result.exit_status == 0: print(result.stdout, end='') else: print(result.stderr, end='', file=sys.stderr) print('Program exited with status %d' % result.exit_status, file=sys.stderr) try: asyncio.get_event_loop().run_until_complete(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-1.11.1/examples/chroot_sftp_server.py000077500000000000000000000026451320320510200215340ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2016 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. 
An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. # # The file ``ssh_user_ca`` must exist with a cert-authority entry of # the certificate authority which can sign valid client certificates. import asyncio, asyncssh, os, sys class MySFTPServer(asyncssh.SFTPServer): def __init__(self, conn): root = '/tmp/sftp/' + conn.get_extra_info('username') os.makedirs(root, exist_ok=True) super().__init__(conn, chroot=root) async def start_server(): await asyncssh.listen('', 8022, server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca', sftp_factory=MySFTPServer) loop = asyncio.get_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('Error starting server: ' + str(exc)) loop.run_forever() asyncssh-1.11.1/examples/direct_client.py000077500000000000000000000024731320320510200204230ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2013-2016 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys class MySSHTCPSession(asyncssh.SSHTCPSession): def data_received(self, data, datatype): # We use sys.stdout.buffer here because we're writing bytes sys.stdout.buffer.write(data) def connection_lost(self, exc): if exc: print('Direct connection error:', str(exc), file=sys.stderr) async def run_client(): async with asyncssh.connect('localhost') as conn: chan, session = await conn.create_connection(MySSHTCPSession, 'www.google.com', 80) # By default, TCP connections send and receive bytes chan.write(b'HEAD / HTTP/1.0\r\n\r\n') chan.write_eof() await chan.wait_closed() try: asyncio.get_event_loop().run_until_complete(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-1.11.1/examples/direct_server.py000077500000000000000000000033651320320510200204540ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2013-2016 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. # # The file ``ssh_user_ca`` must exist with a cert-authority entry of # the certificate authority which can sign valid client certificates. 
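# Illustrative sketch (hypothetical helper, not used by the example below):
# what a client might see when this server rejects a request for a port other
# than 7. It assumes the server below is running on localhost port 8022 and
# accepts this client's keys; the destination host string is arbitrary, since
# this server only checks the destination port.

async def _show_rejected_port(host='localhost', port=8022):
    import asyncssh   # local import; the example's own imports follow below

    async with asyncssh.connect(host, port=port) as conn:
        # Port 7 is echoed by the server below; any other port is refused
        # with a ChannelOpenError, which the client can catch like this.
        try:
            await conn.open_connection('anyhost', 8080)
        except asyncssh.ChannelOpenError as exc:
            print('Connection rejected:', exc.reason)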
import asyncio, asyncssh, sys class MySSHTCPSession(asyncssh.SSHTCPSession): def connection_made(self, chan): self._chan = chan def data_received(self, data, datatype): self._chan.write(data) class MySSHServer(asyncssh.SSHServer): def connection_requested(self, dest_host, dest_port, orig_host, orig_port): if dest_port == 7: return MySSHTCPSession() else: raise asyncssh.ChannelOpenError( asyncssh.OPEN_ADMINISTRATIVELY_PROHIBITED, 'Only echo connections allowed') async def start_server(): await asyncssh.create_server(MySSHServer, '', 8022, server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca') loop = asyncio.get_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH server failed: ' + str(exc)) loop.run_forever() asyncssh-1.11.1/examples/editor.py000077500000000000000000000033521320320510200170760ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2013-2017 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. # # The file ``ssh_user_ca`` must exist with a cert-authority entry of # the certificate authority which can sign valid client certificates. import asyncio, asyncssh, sys async def handle_client(process): process.stdout.write('Welcome to my SSH server, %s!\n\n' % process.channel.get_extra_info('username')) process.channel.set_echo(False) process.stdout.write('Tell me a secret: ') secret = await process.stdin.readline() process.stdin.channel.set_line_mode(False) process.stdout.write('\nYour secret is safe with me! ' 'Press any key to exit...') await process.stdin.read(1) process.stdout.write('\n') process.exit(0) async def start_server(): await asyncssh.listen('', 8022, server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca', process_factory=handle_client) loop = asyncio.get_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('Error starting server: ' + str(exc)) loop.run_forever() asyncssh-1.11.1/examples/gather_results.py000077500000000000000000000024211320320510200206370ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2016 by Ron Frederick . # All rights reserved. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh async def run_client(host, command): async with asyncssh.connect(host) as conn: return await conn.run(command) async def run_multiple_clients(): # Put your lists of hosts here hosts = 5 * ['localhost'] tasks = (run_client(host, 'ls abc') for host in hosts) results = await asyncio.gather(*tasks, return_exceptions=True) for i, result in enumerate(results, 1): if isinstance(result, Exception): print('Task %d failed: %s' % (i, str(result))) elif result.exit_status != 0: print('Task %d exited with status %s:' % (i, result.exit_status)) print(result.stderr, end='') else: print('Task %d succeeded:' % i) print(result.stdout, end='') print(75*'-') asyncio.get_event_loop().run_until_complete(run_multiple_clients()) asyncssh-1.11.1/examples/listening_client.py000077500000000000000000000024451320320510200211440ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2013-2016 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys class MySSHTCPSession(asyncssh.SSHTCPSession): def connection_made(self, chan): self._chan = chan def data_received(self, data, datatype): self._chan.write(data) def connection_requested(orig_host, orig_port): print('Connection received from %s, port %s' % (orig_host, orig_port)) return MySSHTCPSession() async def run_client(): async with asyncssh.connect('localhost') as conn: server = await conn.create_server(connection_requested, '', 8888, encoding='utf-8') if server: await server.wait_closed() else: print('Listener couldn''t be opened.', file=sys.stderr) try: asyncio.get_event_loop().run_until_complete(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-1.11.1/examples/local_forwarding_client.py000077500000000000000000000014631320320510200224630ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2013-2016 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys async def run_client(): async with asyncssh.connect('localhost') as conn: listener = await conn.forward_local_port('', 8080, 'www.google.com', 80) await listener.wait_closed() try: asyncio.get_event_loop().run_until_complete(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-1.11.1/examples/local_forwarding_client2.py000077500000000000000000000015571320320510200225510ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2013-2016 by Ron Frederick . # All rights reserved. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys async def run_client(): async with asyncssh.connect('localhost') as conn: listener = await conn.forward_local_port('', 0, 'www.google.com', 80) print('Listening on port %s...' % listener.get_port()) await listener.wait_closed() try: asyncio.get_event_loop().run_until_complete(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-1.11.1/examples/local_forwarding_server.py000077500000000000000000000030671320320510200225150ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2013-2016 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. # # The file ``ssh_user_ca`` must exist with a cert-authority entry of # the certificate authority which can sign valid client certificates. import asyncio, asyncssh, sys class MySSHServer(asyncssh.SSHServer): def connection_requested(self, dest_host, dest_port, orig_host, orig_port): if dest_port == 80: return True else: raise asyncssh.ChannelOpenError( asyncssh.OPEN_ADMINISTRATIVELY_PROHIBITED, 'Only connections to port 80 are allowed') async def start_server(): await asyncssh.create_server(MySSHServer, '', 8022, server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca') loop = asyncio.get_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH server failed: ' + str(exc)) loop.run_forever() asyncssh-1.11.1/examples/math_client.py000077500000000000000000000016721320320510200201020ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2016 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys async def run_client(): async with asyncssh.connect('localhost') as conn: async with conn.create_process('bc') as process: for op in ['2+2', '1*2*3*4', '2^32']: process.stdin.write(op + '\n') result = await process.stdout.readline() print(op, '=', result, end='') try: asyncio.get_event_loop().run_until_complete(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-1.11.1/examples/math_server.py000077500000000000000000000033271320320510200201310ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2013-2017 by Ron Frederick . # All rights reserved. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. # # The file ``ssh_user_ca`` must exist with a cert-authority entry of # the certificate authority which can sign valid client certificates. import asyncio, asyncssh, sys async def handle_client(process): process.stdout.write('Enter numbers one per line, or EOF when done:\n') total = 0 try: async for line in process.stdin: line = line.rstrip('\n') if line: try: total += int(line) except ValueError: process.stderr.write('Invalid number: %s\n' % line) except asyncssh.BreakReceived: pass process.stdout.write('Total = %s\n' % total) process.exit(0) async def start_server(): await asyncssh.listen('', 8022, server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca', process_factory=handle_client) loop = asyncio.get_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('Error starting server: ' + str(exc)) loop.run_forever() asyncssh-1.11.1/examples/redirect_input.py000077500000000000000000000014101320320510200206210ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2013-2016 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys async def run_client(): async with asyncssh.connect('localhost') as conn: await conn.run('tail -r', input='1\n2\n3\n', stdout='/tmp/stdout') try: asyncio.get_event_loop().run_until_complete(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-1.11.1/examples/redirect_local_pipe.py000077500000000000000000000017031320320510200215760ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2013-2016 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, subprocess, sys async def run_client(): async with asyncssh.connect('localhost') as conn: local_proc = subprocess.Popen(r'echo "1\n2\n3"', shell=True, stdout=subprocess.PIPE) remote_result = await conn.run('tail -r', stdin=local_proc.stdout) print(remote_result.stdout, end='') try: asyncio.get_event_loop().run_until_complete(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-1.11.1/examples/redirect_remote_pipe.py000077500000000000000000000015521320320510200220010ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2013-2016 by Ron Frederick . # All rights reserved. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys async def run_client(): async with asyncssh.connect('localhost') as conn: proc1 = await conn.create_process(r'echo "1\n2\n3"') proc2_result = await conn.run('tail -r', stdin=proc1.stdout) print(proc2_result.stdout, end='') try: asyncio.get_event_loop().run_until_complete(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-1.11.1/examples/redirect_server.py000077500000000000000000000030771320320510200210030ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2017 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. # # The file ``ssh_user_ca`` must exist with a cert-authority entry of # the certificate authority which can sign valid client certificates. import asyncio, asyncssh, subprocess, sys async def handle_client(process): bc_proc = subprocess.Popen('bc', shell=True, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) await process.redirect(stdin=bc_proc.stdin, stdout=bc_proc.stdout, stderr=bc_proc.stderr) await process.stdout.drain() process.exit(0) async def start_server(): await asyncssh.listen('', 8022, server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca', process_factory=handle_client) loop = asyncio.get_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('Error starting server: ' + str(exc)) loop.run_forever() asyncssh-1.11.1/examples/remote_forwarding_client.py000077500000000000000000000014571320320510200226670ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2013-2016 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys async def run_client(): async with asyncssh.connect('localhost') as conn: listener = await conn.forward_remote_port('', 8080, 'localhost', 80) await listener.wait_closed() try: asyncio.get_event_loop().run_until_complete(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-1.11.1/examples/remote_forwarding_client2.py000077500000000000000000000022041320320510200227400ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2013-2016 by Ron Frederick . # All rights reserved. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys def connection_requested(orig_host, orig_port): global conn if orig_host in ('127.0.0.1', '::1'): return conn.forward_connection('localhost', 80) else: raise asyncssh.ChannelOpenError( asyncssh.OPEN_ADMINISTRATIVELY_PROHIBITED, 'Connections only allowed from localhost') async def run_client(): global conn async with asyncssh.connect('localhost') as conn: listener = await conn.create_server(connection_requested, '', 8080) await listener.wait_closed() try: asyncio.get_event_loop().run_until_complete(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-1.11.1/examples/remote_forwarding_server.py000077500000000000000000000025231320320510200227120ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2013-2016 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. # # The file ``ssh_user_ca`` must exist with a cert-authority entry of # the certificate authority which can sign valid client certificates. import asyncio, asyncssh, sys class MySSHServer(asyncssh.SSHServer): def server_requested(self, listen_host, listen_port): return listen_port == 8080 async def start_server(): await asyncssh.create_server(MySSHServer, '', 8022, server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca') loop = asyncio.get_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH server failed: ' + str(exc)) loop.run_forever() asyncssh-1.11.1/examples/scp_client.py000077500000000000000000000012671320320510200177360ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2017 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys async def run_client(): await asyncssh.scp('localhost:example.txt', '.') try: asyncio.get_event_loop().run_until_complete(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SFTP operation failed: ' + str(exc)) asyncssh-1.11.1/examples/set_environment.py000077500000000000000000000015371320320510200210320ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2013-2016 by Ron Frederick . # All rights reserved. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys async def run_client(): async with asyncssh.connect('localhost') as conn: result = await conn.run('env', env={'LANG': 'en_GB', 'LC_COLLATE': 'C'}) print(result.stdout, end='') try: asyncio.get_event_loop().run_until_complete(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-1.11.1/examples/set_terminal.py000077500000000000000000000016101320320510200202710ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2013-2016 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys async def run_client(): async with asyncssh.connect('localhost') as conn: result = await conn.run('echo $TERM; stty size', term_type='xterm-color', term_size=(80, 24)) print(result.stdout, end='') try: asyncio.get_event_loop().run_until_complete(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-1.11.1/examples/sftp_client.py000077500000000000000000000014341320320510200201210ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2015-2016 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys async def run_client(): async with asyncssh.connect('localhost') as conn: async with conn.start_sftp_client() as sftp: await sftp.get('example.txt') try: asyncio.get_event_loop().run_until_complete(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SFTP operation failed: ' + str(exc)) asyncssh-1.11.1/examples/show_environment.py000077500000000000000000000031311320320510200212070ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2013-2017 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. # # The file ``ssh_user_ca`` must exist with a cert-authority entry of # the certificate authority which can sign valid client certificates. 
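# Illustrative sketch (hypothetical helper, not part of the example below):
# one way an AsyncSSH client could send environment variables to this server,
# assuming it is running on localhost port 8022 and accepts this client's
# keys. The variable names and values are arbitrary, and the command string is
# just a placeholder since this server's process factory ignores it.

async def _send_some_env(host='localhost', port=8022):
    import asyncssh   # local import; the example's own imports follow below

    async with asyncssh.connect(host, port=port) as conn:
        result = await conn.run('env', env={'LANG': 'en_GB', 'LC_COLLATE': 'C'})
        print(result.stdout, end='')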
import asyncio, asyncssh, sys async def handle_client(process): if process.env: keywidth = max(map(len, process.env.keys()))+1 process.stdout.write('Environment:\n') for key, value in process.env.items(): process.stdout.write(' %-*s %s\n' % (keywidth, key+':', value)) process.exit(0) else: process.stderr.write('No environment sent.\n') process.exit(1) async def start_server(): await asyncssh.listen('', 8022, server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca', process_factory=handle_client) loop = asyncio.get_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('Error starting server: ' + str(exc)) loop.run_forever() asyncssh-1.11.1/examples/show_terminal.py000077500000000000000000000041161320320510200204620ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2013-2017 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. # # The file ``ssh_user_ca`` must exist with a cert-authority entry of # the certificate authority which can sign valid client certificates. import asyncio, asyncssh, sys async def handle_client(process): term_type = process.get_terminal_type() width, height, pixwidth, pixheight = process.get_terminal_size() process.stdout.write('Terminal type: %s, size: %sx%s' % (term_type, width, height)) if pixwidth and pixheight: process.stdout.write(' (%sx%s pixels)' % (pixwidth, pixheight)) process.stdout.write('\nTry resizing your window!\n') while not process.stdin.at_eof(): try: await process.stdin.read() except asyncssh.TerminalSizeChanged as exc: process.stdout.write('New window size: %sx%s' % (exc.width, exc.height)) if exc.pixwidth and exc.pixheight: process.stdout.write(' (%sx%s pixels)' % (exc.pixwidth, exc.pixheight)) process.stdout.write('\n') async def start_server(): await asyncssh.listen('', 8022, server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca', process_factory=handle_client) loop = asyncio.get_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('Error starting server: ' + str(exc)) loop.run_forever() asyncssh-1.11.1/examples/simple_cert_server.py000077500000000000000000000026021320320510200215010ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2013-2017 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. # # The file ``ssh_user_ca`` must exist with a cert-authority entry of # the certificate authority which can sign valid client certificates. 
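# Illustrative sketch (hypothetical helper, not part of the example below):
# connecting to this certificate-based server from an AsyncSSH client. It
# assumes the server below is on localhost port 8022, that 'ssh_user_key' is a
# client private key file, and that a certificate signed by the CA listed in
# 'ssh_user_ca' is available alongside it as 'ssh_user_key-cert.pub' with the
# client's username among its principals. The command string is a placeholder
# which the process factory below ignores.

async def _connect_with_cert(host='localhost', port=8022):
    import asyncssh   # local import; the example's own imports follow below

    async with asyncssh.connect(host, port=port,
                                client_keys=['ssh_user_key']) as conn:
        result = await conn.run('true')
        print(result.stdout, end='')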
import asyncio, asyncssh, sys def handle_client(process): process.stdout.write('Welcome to my SSH server, %s!\n' % process.channel.get_extra_info('username')) process.exit(0) async def start_server(): await asyncssh.listen('', 8022, server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca', process_factory=handle_client) loop = asyncio.get_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('Error starting server: ' + str(exc)) loop.run_forever() asyncssh-1.11.1/examples/simple_client.py000077500000000000000000000014301320320510200204320ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2013-2016 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys async def run_client(): async with asyncssh.connect('localhost') as conn: result = await conn.run('ls abc', check=True) print(result.stdout, end='') try: asyncio.get_event_loop().run_until_complete(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-1.11.1/examples/simple_keyed_server.py000077500000000000000000000033441320320510200216510ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2013-2017 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. # # Authentication requires the directory authorized_keys to exist with # files in it named based on the username containing the client keys # and certificate authority keys which are accepted for that user. import asyncio, asyncssh, sys def handle_client(process): process.stdout.write('Welcome to my SSH server, %s!\n' % process.channel.get_extra_info('username')) process.exit(0) class MySSHServer(asyncssh.SSHServer): def connection_made(self, conn): self._conn = conn def begin_auth(self, username): try: self._conn.set_authorized_keys('authorized_keys/%s' % username) except IOError: pass return True async def start_server(): await asyncssh.create_server(MySSHServer, '', 8022, server_host_keys=['ssh_host_key'], process_factory=handle_client) loop = asyncio.get_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('Error starting server: ' + str(exc)) loop.run_forever() asyncssh-1.11.1/examples/simple_scp_server.py000077500000000000000000000023231320320510200213310ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2015-2016 by Ron Frederick . # All rights reserved. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. # # The file ``ssh_user_ca`` must exist with a cert-authority entry of # the certificate authority which can sign valid client certificates. import asyncio, asyncssh, sys async def start_server(): await asyncssh.listen('', 8022, server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca', sftp_factory=True, allow_scp=True) loop = asyncio.get_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('Error starting server: ' + str(exc)) loop.run_forever() asyncssh-1.11.1/examples/simple_server.py000077500000000000000000000041511320320510200204650ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2013-2017 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. import asyncio, asyncssh, crypt, sys passwords = {'guest': '', # guest account with no password 'user123': 'qV2iEadIGV2rw' # password of 'secretpw' } def handle_client(process): process.stdout.write('Welcome to my SSH server, %s!\n' % process.channel.get_extra_info('username')) process.exit(0) class MySSHServer(asyncssh.SSHServer): def connection_made(self, conn): print('SSH connection received from %s.' % conn.get_extra_info('peername')[0]) def connection_lost(self, exc): if exc: print('SSH connection error: ' + str(exc), file=sys.stderr) else: print('SSH connection closed.') def begin_auth(self, username): # If the user's password is the empty string, no auth is required return passwords.get(username) != '' def password_auth_supported(self): return True def validate_password(self, username, password): pw = passwords.get(username, '*') return crypt.crypt(password, pw) == pw async def start_server(): await asyncssh.create_server(MySSHServer, '', 8022, server_host_keys=['ssh_host_key'], process_factory=handle_client) loop = asyncio.get_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('Error starting server: ' + str(exc)) loop.run_forever() asyncssh-1.11.1/examples/simple_sftp_server.py000077500000000000000000000023031320320510200215160ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2015-2016 by Ron Frederick . # All rights reserved. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. # # The file ``ssh_user_ca`` must exist with a cert-authority entry of # the certificate authority which can sign valid client certificates. import asyncio, asyncssh, sys async def start_server(): await asyncssh.listen('', 8022, server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca', sftp_factory=True) loop = asyncio.get_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('Error starting server: ' + str(exc)) loop.run_forever() asyncssh-1.11.1/examples/stream_direct_client.py000077500000000000000000000020471320320510200217730ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2013-2016 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys async def run_client(): async with asyncssh.connect('localhost') as conn: reader, writer = await conn.open_connection('www.google.com', 80) # By default, TCP connections send and receive bytes writer.write(b'HEAD / HTTP/1.0\r\n\r\n') writer.write_eof() # We use sys.stdout.buffer here because we're writing bytes response = await reader.read() sys.stdout.buffer.write(response) try: asyncio.get_event_loop().run_until_complete(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-1.11.1/examples/stream_direct_server.py000077500000000000000000000034401320320510200220210ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2013-2016 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. # # The file ``ssh_user_ca`` must exist with a cert-authority entry of # the certificate authority which can sign valid client certificates. 
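# Illustrative sketch (hypothetical helper, not used by the example below):
# exercising this streams-based echo service from a client, assuming the
# server below is running on localhost port 8022 and accepts this client's
# keys. Passing an encoding makes the round trip use str instead of the
# default bytes; the destination host is arbitrary since only the port is
# checked by this server.

async def _echo_roundtrip(host='localhost', port=8022):
    import asyncssh   # local import; the example's own imports follow below

    async with asyncssh.connect(host, port=port) as conn:
        reader, writer = await conn.open_connection('anyhost', 7,
                                                    encoding='utf-8')
        writer.write('hello, echo\n')
        writer.write_eof()
        print(await reader.read(), end='')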
import asyncio, asyncssh, sys async def handle_connection(reader, writer): while not reader.at_eof(): data = await reader.read(8192) try: writer.write(data) except BrokenPipeError: break writer.close() class MySSHServer(asyncssh.SSHServer): def connection_requested(self, dest_host, dest_port, orig_host, orig_port): if dest_port == 7: return handle_connection else: raise asyncssh.ChannelOpenError( asyncssh.OPEN_ADMINISTRATIVELY_PROHIBITED, 'Only echo connections allowed') async def start_server(): await asyncssh.create_server(MySSHServer, '', 8022, server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca') loop = asyncio.get_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH server failed: ' + str(exc)) loop.run_forever() asyncssh-1.11.1/examples/stream_listening_client.py000077500000000000000000000022401320320510200225100ustar00rootroot00000000000000#!/usr/bin/env python3.5 # # Copyright (c) 2013-2016 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys async def handle_connection(reader, writer): while not reader.at_eof(): data = await reader.read(8192) writer.write(data) writer.close() def connection_requested(orig_host, orig_port): print('Connection received from %s, port %s' % (orig_host, orig_port)) return handle_connection async def run_client(): async with asyncssh.connect('localhost') as conn: server = await conn.start_server(connection_requested, '', 8888, encoding='utf-8') await server.wait_closed() try: asyncio.get_event_loop().run_until_complete(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-1.11.1/pylintrc000066400000000000000000000257421320320510200152130ustar00rootroot00000000000000[MASTER] # Specify a configuration file. #rcfile= # Python code to execute, usually for sys.path manipulation such as # pygtk.require(). #init-hook= # Add files or directories to the blacklist. They should be base names, not # paths. ignore= # Pickle collected data for later comparisons. persistent=yes # List of plugins (as comma separated values of python modules names) to load, # usually to register additional checkers. load-plugins= # Use multiple processes to speed up Pylint. jobs=1 # Allow loading of arbitrary C extensions. Extensions are imported into the # active Python interpreter and may run arbitrary code. unsafe-load-any-extension=no # A comma-separated list of package or module names from where C extensions may # be loaded. Extensions are loading into the active Python interpreter and may # run arbitrary code extension-pkg-whitelist= # Allow optimization of some AST trees. This will activate a peephole AST # optimizer, which will apply various small optimizations. For instance, it can # be used to obtain the result of joining multiple strings with the addition # operator. Joining a lot of strings can lead to a maximum recursion error in # Pylint and this flag can prevent that. It has one side effect, the resulting # AST will be different than the one from reality. optimize-ast=no [MESSAGES CONTROL] # Only show warnings with the listed confidence levels. Leave empty to show # all. 
Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED confidence= # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option # multiple time. See also the "--disable" option for examples. #enable= # Disable the message, report, category or checker with the given id(s). You # can either give multiple identifiers separated by comma (,) or put this # option multiple times (only on the command line, not in the configuration # file where it should appear only once).You can also use "--disable=all" to # disable everything first and then reenable specific checks. For example, if # you want to run only the similarities checker, you can use "--disable=all # --enable=similarities". If you want to run only the classes checker, but have # no Warning level messages displayed, use"--disable=all --enable=classes # --disable=W" disable=fixme,locally-enabled,locally-disabled,redefined-variable-type,too-many-boolean-expressions,too-many-arguments [REPORTS] # Set the output format. Available formats are text, parseable, colorized, msvs # (visual studio) and html. You can also give a reporter class, eg # mypackage.mymodule.MyReporterClass. output-format=text # Put messages in a separate file for each module / package specified on the # command line instead of printing them on stdout. Reports (if any) will be # written in a file name "pylint_global.[txt|html]". files-output=no # Tells whether to display a full report or only the messages reports=yes # Python expression which should return a note less than 10 (10 is the highest # note). You have access to the variables errors warning, statement which # respectively contain the number of errors / warnings messages and the total # number of statements analyzed. This is used by the global evaluation report # (RP0004). evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) # Template used to display messages. This is a python new-style format string # used to format the message information. See doc for all details #msg-template= [BASIC] # List of builtins function names that should not be used, separated by a comma bad-functions=map,filter # Good variable names which should always be accepted, separated by a comma good-names=a,av,b,c,ca,ch,cn,f,fs,g,h,i,ip,iv,j,k,l,n,r,s,sa,t,v,x,y,_ # Bad variable names which should always be refused, separated by a comma bad-names= # Colon-delimited sets of names that determine each other's naming style when # the name regexes allow several styles. 
name-group= # Include a hint for the correct naming format with invalid-name include-naming-hint=no # Regular expression matching correct argument names argument-rgx=[a-z_][a-z0-9_]{2,75}$ # Naming hint for argument names argument-name-hint=[a-z_][a-z0-9_]{2,75}$ # Regular expression matching correct variable names variable-rgx=[a-z_][a-z0-9_]{2,75}$ # Naming hint for variable names variable-name-hint=[a-z_][a-z0-9_]{2,75}$ # Regular expression matching correct class names class-rgx=[A-Z_][a-zA-Z0-9]+$ # Naming hint for class names class-name-hint=[A-Z_][a-zA-Z0-9]+$ # Regular expression matching correct constant names const-rgx=(([A-Za-z_][A-Za-z0-9_]*)|(__.*__))$ # Naming hint for constant names const-name-hint=(([A-Z][A-Z-9_]*)|(__.*__))$ # Regular expression matching correct method names method-rgx=[a-z_][a-z0-9_]{2,75}$ # Naming hint for method names method-name-hint=[a-z_][a-z0-9_]{2,75}$ # Regular expression matching correct module names module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ # Naming hint for module names module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ # Regular expression matching correct function names function-rgx=[A-Za-z_][A-Za-z0-9_]{2,75}$ # Naming hint for function names function-name-hint=[a-z_][a-z0-9_]{2,75}$ # Regular expression matching correct attribute names attr-rgx=[a-z_][a-z0-9_]{2,75}$ # Naming hint for attribute names attr-name-hint=[a-z_][a-z0-9_]{2,75}$ # Regular expression matching correct inline iteration names inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$ # Naming hint for inline iteration names inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$ # Regular expression matching correct class attribute names class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,75}|(__.*__))$ # Naming hint for class attribute names class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,75}|(__.*__))$ # Regular expression which should only match function or class names that do # not require a docstring. no-docstring-rgx=__.*__ # Minimum line length for functions/classes that require docstrings, shorter # ones are exempt. docstring-min-length=-1 [FORMAT] # Maximum number of characters on a single line. max-line-length=100 # Regexp for a line that is allowed to be longer than the limit. ignore-long-lines=^\s*(# )??$ # Allow the body of an if to be on the same line as the test if there is no # else. single-line-if-stmt=no # List of optional constructs for which whitespace checking is disabled no-space-check=trailing-comma,dict-separator # Maximum number of lines in a module max-module-lines=7500 # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 # tab). indent-string=' ' # Number of spaces of indent required inside a hanging or continued line. indent-after-paren=4 # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. expected-line-ending-format= [LOGGING] # Logging modules to check that the string format arguments are in logging # function parameter format logging-modules=logging [MISCELLANEOUS] # List of note tags to take in consideration, separated by a comma. notes=FIXME,XXX,TODO [SIMILARITIES] # Minimum lines number of a similarity. min-similarity-lines=8 # Ignore comments when computing similarities. ignore-comments=yes # Ignore docstrings when computing similarities. ignore-docstrings=yes # Ignore imports when computing similarities. ignore-imports=no [SPELLING] # Spelling dictionary name. Available dictionaries: none. To make it working # install python-enchant package. 
spelling-dict= # List of comma separated words that should not be checked. spelling-ignore-words= # A path to a file that contains private dictionary; one word per line. spelling-private-dict-file= # Tells whether to store unknown words to indicated private dictionary in # --spelling-private-dict-file option instead of raising a message. spelling-store-unknown-words=no [TYPECHECK] # Tells whether missing members accessed in mixin class should be ignored. A # mixin class is detected if its name ends with "mixin" (case insensitive). ignore-mixin-members=yes # List of module names for which member attributes should not be checked # (useful for modules/projects where namespaces are manipulated during runtime # and thus existing member attributes cannot be deduced by static analysis ignored-modules= # List of classes names for which member attributes should not be checked # (useful for classes with attributes dynamically set). ignored-classes=SQLObject # List of members which are set dynamically and missed by pylint inference # system, and so shouldn't trigger E0201 when accessed. Python regular # expressions are accepted. generated-members=REQUEST,acl_users,aq_parent [VARIABLES] # Tells whether we should check for unused import in __init__ files. init-import=no # A regular expression matching the name of dummy variables (i.e. expectedly # not used). dummy-variables-rgx=_$|dummy # List of additional names supposed to be defined in builtins. Remember that # you should avoid to define new builtins when possible. additional-builtins= # List of strings which can identify a callback function by name. A callback # name must start or end with one of those strings. callbacks=cb_,_cb [CLASSES] # List of method names used to declare (i.e. assign) instance attributes. defining-attr-methods=__init__,__new__,setUp # List of valid names for the first argument in a class method. valid-classmethod-first-arg=cls # List of valid names for the first argument in a metaclass class method. valid-metaclass-classmethod-first-arg=mcs # List of member names, which should be excluded from the protected access # warning. exclude-protected=_asdict,_fields,_replace,_source,_make [DESIGN] # Maximum number of arguments for function / method max-args=25 # Argument names that match this expression will be ignored. Default to name # with leading underscore ignored-argument-names=_.* # Maximum number of locals for function / method body max-locals=50 # Maximum number of return / yield for function / method body max-returns=10 # Maximum number of branch for function / method body max-branches=50 # Maximum number of statements in function / method body max-statements=100 # Maximum number of parents for a class (see R0901). max-parents=10 # Maximum number of attributes for a class (see R0902). max-attributes=100 # Minimum number of public methods for a class (see R0903). min-public-methods=0 # Maximum number of public methods for a class (see R0904). max-public-methods=50 [IMPORTS] # Deprecated modules which should not be used, separated by a comma deprecated-modules=stringprep,optparse # Create a graph of every (i.e. internal and external) dependencies in the # given file (report RP0402 must not be disabled) import-graph= # Create a graph of external dependencies in the given file (report RP0402 must # not be disabled) ext-import-graph= # Create a graph of internal dependencies in the given file (report RP0402 must # not be disabled) int-import-graph= [EXCEPTIONS] # Exceptions that will emit a warning when being caught. 
Defaults to # "Exception" overgeneral-exceptions=Exception asyncssh-1.11.1/setup.py000077500000000000000000000046201320320510200151310ustar00rootroot00000000000000#!/usr/bin/env python3.5 # Copyright (c) 2013-2016 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """AsyncSSH: Asynchronous SSHv2 client and server library AsyncSSH is a Python package which provides an asynchronous client and server implementation of the SSHv2 protocol on top of the Python asyncio framework. It requires Python 3.4 or later and either the PyCA library or the PyCrypto library for some cryptographic functions. """ from os import path from setuptools import setup base_dir = path.abspath(path.dirname(__file__)) doclines = __doc__.split('\n', 1) with open(path.join(base_dir, 'README.rst')) as desc: long_description = desc.read() with open(path.join(base_dir, 'asyncssh', 'version.py')) as version: exec(version.read()) setup(name = 'asyncssh', version = __version__, author = __author__, author_email = __author_email__, url = __url__, license = 'Eclipse Public License v1.0', description = doclines[0], long_description = long_description, platforms = 'Any', install_requires = ['cryptography >= 1.5'], extras_require = { 'bcrypt': ['bcrypt >= 3.0.0'], 'gssapi': ['gssapi >= 1.2.0'], 'libnacl': ['libnacl >= 1.4.2'], 'pyOpenSSL': ['pyOpenSSL >= 17.0.0'], 'pypiwin32': ['pypiwin32 >= 219'] }, packages = ['asyncssh', 'asyncssh.crypto', 'asyncssh.crypto.pyca'], scripts = [], test_suite = 'tests', classifiers = [ 'Development Status :: 5 - Production/Stable', 'Environment :: Console', 'Intended Audience :: Developers', 'License :: OSI Approved', 'Operating System :: MacOS :: MacOS X', 'Operating System :: POSIX', 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', 'Topic :: Internet', 'Topic :: Security :: Cryptography', 'Topic :: Software Development :: Libraries :: Python Modules', 'Topic :: System :: Networking']) asyncssh-1.11.1/tests/000077500000000000000000000000001320320510200145545ustar00rootroot00000000000000asyncssh-1.11.1/tests/__init__.py000066400000000000000000000006631320320510200166720ustar00rootroot00000000000000# Copyright (c) 2014-2016 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for AsyncSSH""" asyncssh-1.11.1/tests/gss_stub.py000066400000000000000000000023031320320510200167550ustar00rootroot00000000000000# Copyright (c) 2017 by Ron Frederick . # All rights reserved. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Stub GSS module for unit tests""" def step(host, token): """Perform next step in GSS authentication""" complete = False if token == b'errtok': return token, complete elif ((token is None and 'empty_init' in host) or (token == b'1' and 'empty_continue' in host)): return b'', complete elif token == b'0': if 'continue_token' in host: token = b'continue' else: complete = True token = b'extra' if 'extra_token' in host else None elif token: token = bytes((token[0]-1,)) else: token = host[0].encode('ascii') if token == b'0': if 'step_error' in host: return (b'errtok' if 'errtok' in host else b'error'), complete complete = True return token, complete asyncssh-1.11.1/tests/gssapi_stub.py000066400000000000000000000063541320320510200174610ustar00rootroot00000000000000# Copyright (c) 2017 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Stub GSSAPI module for unit tests""" from asyncssh.gss import GSSError from .gss_stub import step class Name: """Stub class for GSS principal name""" def __init__(self, base, name_type=None): # pylint: disable=unused-argument if 'init_error' in base: raise GSSError(99, 99) self.host = base[5:] class Credentials: """Stub class for GSS credentials""" def __init__(self, name=None, usage=None): self.host = name.host if name else '' self.server = usage == 'accept' @property def mechs(self): """Return GSS mechanisms available for this host""" if self.server: return [0] if 'unknown_mech' in self.host else [1, 2] else: return [2] class RequirementFlag: """Stub class for GSS requirement flags""" mutual_authentication = 'mutual_auth' integrity = 'integrity' delegate_to_peer = 'delegate' class SecurityContext: """Stub class for GSS security context""" def __init__(self, name=None, creds=None, flags=None): host = creds.host if creds.server else name.host if flags is None: flags = set((RequirementFlag.mutual_authentication, RequirementFlag.integrity)) if ((creds.server and 'no_server_integrity' in host) or (not creds.server and 'no_client_integrity' in host)): flags.remove(RequirementFlag.integrity) self._host = host self._server = creds.server self._actual_flags = flags self._complete = False @property def complete(self): """Return whether or not GSS negotiation is complete""" return self._complete @property def actual_flags(self): """Return flags set on this context""" return self._actual_flags @property def initiator_name(self): """Return user principal associated with this context""" return 'user@TEST' @property def target_name(self): """Return host principal associated with this context""" return 'host@TEST' def step(self, token=None): """Perform next step in GSS security exchange""" token, complete = step(self._host, token) if complete: self._complete = True if token == b'error': raise GSSError(99, 99) elif token == b'errtok': raise GSSError(99, 99, token) else: return token def get_signature(self, data): """Sign a block of data""" # pylint: 
disable=no-self-use,unused-argument return b'fail' if 'fail' in self._host else 'succeed' def verify_signature(self, data, sig): """Verify a signature for a block of data""" # pylint: disable=no-self-use,unused-argument if sig == b'fail': raise GSSError(99, 99) asyncssh-1.11.1/tests/server.py000066400000000000000000000250631320320510200164420ustar00rootroot00000000000000# Copyright (c) 2016-2017 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH server used for unit tests""" import asyncio import os import shutil import signal import socket import subprocess import asyncssh from asyncssh.misc import async_context_manager from .util import AsyncTestCase, run, x509_available class Server(asyncssh.SSHServer): """Unit test SSH server""" def __init__(self): self._conn = None def connection_made(self, conn): """Record connection object for later use""" self._conn = conn def begin_auth(self, username): """Handle client authentication request""" return username != 'guest' class ServerTestCase(AsyncTestCase): """Unit test class which starts an SSH server and agent""" # Pylint doesn't like mixed case method names, but this was chosen to # match the convention used in the unittest module. # pylint: disable=invalid-name _server = None _server_addr = None _server_port = None _agent_pid = None @classmethod @asyncio.coroutine def create_server(cls, server_factory=(), *, loop=(), server_host_keys=(), gss_host=None, **kwargs): """Create an SSH server for the tests to use""" if loop is (): loop = cls.loop if server_factory is (): server_factory = Server if server_host_keys is (): server_host_keys = ['skey'] return (yield from asyncssh.create_server( server_factory, port=0, family=socket.AF_INET, loop=loop, server_host_keys=server_host_keys, gss_host=gss_host, **kwargs)) @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server for the tests to use""" return (yield from cls.create_server()) @classmethod @asyncio.coroutine def asyncSetUpClass(cls): """Set up keys, an SSH server, and an SSH agent for the tests to use""" # pylint: disable=too-many-statements ckey = asyncssh.generate_private_key('ssh-rsa') ckey.write_private_key('ckey') ckey.write_private_key('ckey_encrypted', passphrase='passphrase') ckey.write_public_key('ckey.pub') ckey_ecdsa = asyncssh.generate_private_key('ecdsa-sha2-nistp256') ckey_ecdsa.write_private_key('ckey_ecdsa') ckey_ecdsa.write_public_key('ckey_ecdsa.pub') ckey_cert = ckey.generate_user_certificate(ckey, 'name', principals=['ckey']) ckey_cert.write_certificate('ckey-cert.pub') skey = asyncssh.generate_private_key('ssh-rsa') skey.write_private_key('skey') skey.write_public_key('skey.pub') skey_ecdsa = asyncssh.generate_private_key('ecdsa-sha2-nistp256') skey_ecdsa.write_private_key('skey_ecdsa') skey_ecdsa.write_public_key('skey_ecdsa.pub') skey_cert = skey.generate_host_certificate(skey, 'name', principals=['127.0.0.1']) skey_cert.write_certificate('skey-cert.pub') exp_cert = skey.generate_host_certificate(skey, 'name', valid_after='-2d', valid_before='-1d') skey.write_private_key('exp_skey') exp_cert.write_certificate('exp_skey-cert.pub') if x509_available: # pragma: no branch ckey_x509_self = ckey_ecdsa.generate_x509_user_certificate( ckey_ecdsa, 'OU=name', 
principals=['ckey']) ckey_ecdsa.write_private_key('ckey_x509_self') ckey_x509_self.append_certificate('ckey_x509_self', 'pem') ckey_x509_self.write_certificate('ckey_x509_self.pem', 'pem') ckey_x509_self.write_certificate('ckey_x509_self.pub') skey_x509_self = skey_ecdsa.generate_x509_host_certificate( skey_ecdsa, 'OU=name', principals=['127.0.0.1']) skey_ecdsa.write_private_key('skey_x509_self') skey_x509_self.append_certificate('skey_x509_self', 'pem') skey_x509_self.write_certificate('skey_x509_self.pem', 'pem') root_ca_key = asyncssh.generate_private_key('ssh-rsa') root_ca_key.write_private_key('root_ca_key') root_ca_cert = root_ca_key.generate_x509_ca_certificate( root_ca_key, 'OU=RootCA', ca_path_len=1) root_ca_cert.write_certificate('root_ca_cert.pem', 'pem') root_ca_cert.write_certificate('root_ca_cert.pub') int_ca_key = asyncssh.generate_private_key('ssh-rsa') int_ca_key.write_private_key('int_ca_key') int_ca_cert = root_ca_key.generate_x509_ca_certificate( int_ca_key, 'OU=IntCA', 'OU=RootCA', ca_path_len=0) int_ca_cert.write_certificate('int_ca_cert.pem', 'pem') ckey_x509_chain = int_ca_key.generate_x509_user_certificate( ckey, 'OU=name', 'OU=IntCA', principals=['ckey']) ckey.write_private_key('ckey_x509_chain') ckey_x509_chain.append_certificate('ckey_x509_chain', 'pem') int_ca_cert.append_certificate('ckey_x509_chain', 'pem') ckey_x509_chain.write_certificate('ckey_x509_partial.pem', 'pem') skey_x509_chain = int_ca_key.generate_x509_host_certificate( skey, 'OU=name', 'OU=IntCA', principals=['127.0.0.1']) skey.write_private_key('skey_x509_chain') skey_x509_chain.append_certificate('skey_x509_chain', 'pem') int_ca_cert.append_certificate('skey_x509_chain', 'pem') root_hash = root_ca_cert.x509_cert.subject_hash os.mkdir('cert_path') shutil.copy('root_ca_cert.pem', os.path.join('cert_path', root_hash + '.0')) # Intentional hash mismatch shutil.copy('int_ca_cert.pem', os.path.join('cert_path', root_hash + '.1')) for f in ('ckey', 'ckey_ecdsa', 'skey', 'exp_skey', 'skey_ecdsa'): os.chmod(f, 0o600) os.mkdir('.ssh', 0o700) os.mkdir('.ssh/crt', 0o700) shutil.copy('ckey_ecdsa', os.path.join('.ssh', 'id_ecdsa')) shutil.copy('ckey_ecdsa.pub', os.path.join('.ssh', 'id_ecdsa.pub')) shutil.copy('ckey_encrypted', os.path.join('.ssh', 'id_rsa')) shutil.copy('ckey.pub', os.path.join('.ssh', 'id_rsa.pub')) with open('authorized_keys', 'w') as auth_keys: with open('ckey.pub') as ckey_pub: shutil.copyfileobj(ckey_pub, auth_keys) with open('ckey_ecdsa.pub') as ckey_ecdsa_pub: shutil.copyfileobj(ckey_ecdsa_pub, auth_keys) auth_keys.write('cert-authority,principals="ckey",' 'permitopen=:* ') with open('ckey.pub') as ckey_pub: shutil.copyfileobj(ckey_pub, auth_keys) if x509_available: # pragma: no branch with open('authorized_keys_x509', 'w') as auth_keys_x509: with open('ckey_x509_self.pub') as ckey_self_pub: shutil.copyfileobj(ckey_self_pub, auth_keys_x509) auth_keys_x509.write('cert-authority,principals="ckey" ') with open('root_ca_cert.pub') as root_pub: shutil.copyfileobj(root_pub, auth_keys_x509) cls._server = yield from cls.start_server() sock = cls._server.sockets[0] cls._server_addr = '127.0.0.1' cls._server_port = sock.getsockname()[1] host = '[%s]:%s ' % (cls._server_addr, cls._server_port) with open('known_hosts', 'w') as known_hosts: known_hosts.write(host) with open('skey.pub') as skey_pub: shutil.copyfileobj(skey_pub, known_hosts) known_hosts.write('@cert-authority ' + host) with open('skey.pub') as skey_pub: shutil.copyfileobj(skey_pub, known_hosts) shutil.copy('known_hosts', 
os.path.join('.ssh', 'known_hosts')) os.environ['LOGNAME'] = 'guest' os.environ['HOME'] = '.' if 'DISPLAY' in os.environ: # pragma: no cover del os.environ['DISPLAY'] if 'SSH_ASKPASS' in os.environ: # pragma: no cover del os.environ['SSH_ASKPASS'] if 'SSH_AUTH_SOCK' in os.environ: # pragma: no cover del os.environ['SSH_AUTH_SOCK'] if 'XAUTHORITY' in os.environ: # pragma: no cover del os.environ['XAUTHORITY'] try: output = run('ssh-agent -a agent 2>/dev/null') except subprocess.CalledProcessError: # pragma: no cover cls._agent_pid = None else: cls._agent_pid = int(output.splitlines()[2].split()[3][:-1]) os.environ['SSH_AUTH_SOCK'] = 'agent' agent = yield from asyncssh.connect_agent() yield from agent.add_keys([ckey_ecdsa, (ckey, ckey_cert)]) agent.close() @classmethod @asyncio.coroutine def asyncTearDownClass(cls): """Shut down test server and agent""" # Wait a bit for existing tasks to exit yield from asyncio.sleep(1) cls._server.close() yield from cls._server.wait_closed() if cls._agent_pid: # pragma: no branch os.kill(cls._agent_pid, signal.SIGTERM) # pylint: enable=invalid-name def agent_available(self): """Return whether SSH agent is available""" return bool(self._agent_pid) @asyncio.coroutine def create_connection(self, client_factory, loop=(), gss_host=None, **kwargs): """Create a connection to the test server""" if loop is (): loop = self.loop return (yield from asyncssh.create_connection(client_factory, self._server_addr, self._server_port, loop=loop, gss_host=gss_host, **kwargs)) @async_context_manager def connect(self, **kwargs): """Open a connection to the test server""" conn, _ = yield from self.create_connection(None, **kwargs) return conn asyncssh-1.11.1/tests/sspi_stub.py000066400000000000000000000064171320320510200171510ustar00rootroot00000000000000# Copyright (c) 2017 by Ron Frederick . # All rights reserved. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Stub SSPI module for unit tests""" from asyncssh.gss_win32 import ASC_RET_INTEGRITY, ISC_RET_INTEGRITY from asyncssh.gss_win32 import SECPKG_ATTR_NATIVE_NAMES, SSPIError from .gss_stub import step class SSPIBuffer: """Stub class for SSPI buffer""" def __init__(self, data): self._data = data @property def Buffer(self): # pylint: disable=invalid-name """Return the data in the buffer""" return self._data class SSPIContext: """Stub class for SSPI security context""" def QueryContextAttributes(self, attr): # pylint: disable=invalid-name """Return principal information associated with this context""" # pylint: disable=no-self-use if attr == SECPKG_ATTR_NATIVE_NAMES: return ['user@TEST', 'host@TEST'] class SSPIAuth: """Stub class for SSPI authentication""" def __init__(self, package=None, spn=None, targetspn=None, scflags=None): # pylint: disable=unused-argument host = spn or targetspn if 'init_error' in host: raise SSPIError('Authentication initialization error') if targetspn and 'no_client_integrity' in host: scflags &= ~ISC_RET_INTEGRITY elif spn and 'no_server_integrity' in host: scflags &= ~ASC_RET_INTEGRITY self._host = host[5:] self._flags = scflags self._ctxt = SSPIContext() self._complete = False self._error = False @property def authenticated(self): """Return whether authentication is complete""" return self._complete @property def ctxt(self): """Return authentication context""" return self._ctxt @property def ctxt_attr(self): """Return authentication flags""" return self._flags def reset(self): """Reset SSPI security context""" self._complete = False def authorize(self, token): """Perform next step in SSPI authentication""" if self._error: self._error = False raise SSPIError('Token authentication errror') new_token, complete = step(self._host, token) if complete: self._complete = True if new_token in (b'error', b'errtok'): if token: raise SSPIError('Token authentication errror') else: self._error = True return True, [SSPIBuffer(b'')] else: return bool(new_token), [SSPIBuffer(new_token)] def sign(self, data): """Sign a block of data""" # pylint: disable=no-self-use,unused-argument return b'fail' if 'fail' in self._host else 'succeed' def verify(self, data, sig): """Verify a signature for a block of data""" # pylint: disable=no-self-use,unused-argument if sig == b'fail': raise SSPIError('Signature verification error') asyncssh-1.11.1/tests/test_agent.py000066400000000000000000000273321320320510200172720ustar00rootroot00000000000000# Copyright (c) 2016 by Ron Frederick . # All rights reserved. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for AsyncSSH ssh-agent client""" import asyncio import functools import os import signal import subprocess import unittest import asyncssh from asyncssh.agent import SSH_AGENT_SUCCESS, SSH_AGENT_FAILURE from asyncssh.packet import Byte, String from .util import AsyncTestCase, asynctest, libnacl_available, run def agent_test(func): """Decorator for running SSH agent tests""" @asynctest @functools.wraps(func) def agent_wrapper(self): """Run a test coroutine after connecting to an SSH agent""" agent = yield from asyncssh.connect_agent() yield from agent.remove_all() yield from asyncio.coroutine(func)(self, agent) agent.close() return agent_wrapper class _Agent: """Mock SSH agent for testing error cases""" def __init__(self, response): self._response = response self._path = None self._server = None @asyncio.coroutine def start(self, path): """Start a new mock SSH agent""" self._path = path # pylint doesn't think start_unix_server exists # pylint: disable=no-member self._server = \ yield from asyncio.start_unix_server(self.process_request, path) @asyncio.coroutine def process_request(self, _reader, writer): """Process a request sent to the mock SSH agent""" yield from _reader.readexactly(4) writer.write(self._response) writer.close() @asyncio.coroutine def stop(self): """Shut down the mock SSH agent""" self._server.close() yield from self._server.wait_closed() os.remove(self._path) class _TestAPI(AsyncTestCase): """Unit tests for AsyncSSH API""" _agent_pid = None _public_keys = {} # Pylint doesn't like mixed case method names, but this was chosen to # match the convention used in the unittest module. # pylint: disable=invalid-name @staticmethod def set_askpass(status): """Set return status for ssh-askpass""" with open('ssh-askpass', 'w') as f: f.write('#!/bin/sh\nexit %d\n' % status) os.chmod('ssh-askpass', 0o755) @classmethod @asyncio.coroutine def asyncSetUpClass(cls): """Set up keys and an SSH server for the tests to use""" os.environ['DISPLAY'] = '' os.environ['HOME'] = '.' 
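        # The next assignment points SSH_ASKPASS at the stub script written
        # by set_askpass() above, so tests such as test_confirm() can control
        # whether ssh-agent key confirmation prompts succeed or fail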
os.environ['SSH_ASKPASS'] = os.path.join(os.getcwd(), 'ssh-askpass') try: output = run('ssh-agent -a agent 2>/dev/null') except subprocess.CalledProcessError: # pragma: no cover raise unittest.SkipTest('ssh-agent not available') cls._agent_pid = int(output.splitlines()[2].split()[3][:-1]) os.environ['SSH_AUTH_SOCK'] = 'agent' @classmethod @asyncio.coroutine def asyncTearDownClass(cls): """Shut down agents""" os.kill(cls._agent_pid, signal.SIGTERM) # pylint: enable=invalid-name @agent_test def test_connection(self, agent): """Test opening a connection to the agent""" self.assertIsNotNone(agent) @asynctest def test_connection_failed(self): """Test failure in opening a connection to the agent""" self.assertIsNone((yield from asyncssh.connect_agent('xxx'))) @asynctest def test_no_auth_sock(self): """Test failure when no auth sock is set""" del os.environ['SSH_AUTH_SOCK'] self.assertIsNone((yield from asyncssh.connect_agent())) os.environ['SSH_AUTH_SOCK'] = 'agent' @asynctest def test_explicit_loop(self): """Test passing the event loop explicitly""" loop = asyncio.get_event_loop() agent = yield from asyncssh.connect_agent(loop=loop) self.assertIsNotNone(agent) agent.close() @agent_test def test_get_keys(self, agent): """Test getting keys from the agent""" keys = yield from agent.get_keys() self.assertEqual(len(keys), len(self._public_keys)) @agent_test def test_sign(self, agent): """Test signing a block of data using the agent""" algs = ['ssh-dss', 'ssh-rsa', 'ecdsa-sha2-nistp256'] if libnacl_available: # pragma: no branch algs.append('ssh-ed25519') for alg_name in algs: key = asyncssh.generate_private_key(alg_name) pubkey = key.convert_to_public() cert = key.generate_user_certificate(key, 'name') yield from agent.add_keys([(key, cert)]) agent_keys = yield from agent.get_keys() for agent_key in agent_keys: sig = yield from agent_key.sign(b'test') self.assertTrue(pubkey.verify(b'test', sig)) yield from agent.remove_keys(agent_keys) @agent_test def test_reconnect(self, agent): """Test reconnecting to the agent after closing it""" key = asyncssh.generate_private_key('ssh-rsa') pubkey = key.convert_to_public() yield from agent.add_keys([key]) agent_keys = yield from agent.get_keys() agent.close() for agent_key in agent_keys: sig = yield from agent_key.sign(b'test') self.assertTrue(pubkey.verify(b'test', sig)) @agent_test def test_add_remove_keys(self, agent): """Test adding and removing keys""" yield from agent.add_keys() agent_keys = yield from agent.get_keys() self.assertEqual(len(agent_keys), 0) key = asyncssh.generate_private_key('ssh-rsa') yield from agent.add_keys([key]) agent_keys = yield from agent.get_keys() self.assertEqual(len(agent_keys), 1) yield from agent.remove_keys(agent_keys) agent_keys = yield from agent.get_keys() self.assertEqual(len(agent_keys), 0) yield from agent.add_keys([key]) agent_keys = yield from agent.get_keys() self.assertEqual(len(agent_keys), 1) yield from agent_keys[0].remove() agent_keys = yield from agent.get_keys() self.assertEqual(len(agent_keys), 0) yield from agent.add_keys([key], lifetime=1) agent_keys = yield from agent.get_keys() self.assertEqual(len(agent_keys), 1) yield from asyncio.sleep(2) agent_keys = yield from agent.get_keys() self.assertEqual(len(agent_keys), 0) @asynctest def test_add_remove_smartcard_keys(self): """Test adding and removing smart card keys""" mock_agent = _Agent(String(Byte(SSH_AGENT_SUCCESS))) yield from mock_agent.start('mock_agent') agent = yield from asyncssh.connect_agent('mock_agent') result = yield from 
agent.add_smartcard_keys('provider') self.assertIsNone(result) agent.close() yield from mock_agent.stop() mock_agent = _Agent(String(Byte(SSH_AGENT_SUCCESS))) yield from mock_agent.start('mock_agent') agent = yield from asyncssh.connect_agent('mock_agent') result = yield from agent.remove_smartcard_keys('provider') self.assertIsNone(result) agent.close() yield from mock_agent.stop() @agent_test def test_confirm(self, agent): """Test confirmation of key""" key = asyncssh.generate_private_key('ssh-rsa') pubkey = key.convert_to_public() yield from agent.add_keys([key], confirm=True) agent_keys = yield from agent.get_keys() self.set_askpass(1) for agent_key in agent_keys: with self.assertRaises(ValueError): sig = yield from agent_key.sign(b'test') self.set_askpass(0) for agent_key in agent_keys: sig = yield from agent_key.sign(b'test') self.assertTrue(pubkey.verify(b'test', sig)) @agent_test def test_lock(self, agent): """Test lock and unlock""" key = asyncssh.generate_private_key('ssh-rsa') pubkey = key.convert_to_public() yield from agent.add_keys([key]) agent_keys = yield from agent.get_keys() yield from agent.lock('passphrase') for agent_key in agent_keys: with self.assertRaises(ValueError): yield from agent_key.sign(b'test') yield from agent.unlock('passphrase') for agent_key in agent_keys: sig = yield from agent_key.sign(b'test') self.assertTrue(pubkey.verify(b'test', sig)) @asynctest def test_query_extensions(self): """Test query of supported extensions""" mock_agent = _Agent(String(Byte(SSH_AGENT_SUCCESS) + String('xxx'))) yield from mock_agent.start('mock_agent') agent = yield from asyncssh.connect_agent('mock_agent') extensions = yield from agent.query_extensions() self.assertEqual(extensions, ['xxx']) agent.close() yield from mock_agent.stop() mock_agent = _Agent(String(Byte(SSH_AGENT_SUCCESS) + String(b'\xff'))) yield from mock_agent.start('mock_agent') agent = yield from asyncssh.connect_agent('mock_agent') with self.assertRaises(ValueError): yield from agent.query_extensions() agent.close() yield from mock_agent.stop() mock_agent = _Agent(String(Byte(SSH_AGENT_FAILURE))) yield from mock_agent.start('mock_agent') agent = yield from asyncssh.connect_agent('mock_agent') extensions = yield from agent.query_extensions() self.assertEqual(extensions, []) agent.close() yield from mock_agent.stop() mock_agent = _Agent(String(b'\xff')) yield from mock_agent.start('mock_agent') agent = yield from asyncssh.connect_agent('mock_agent') with self.assertRaises(ValueError): yield from agent.query_extensions() agent.close() yield from mock_agent.stop() @agent_test def test_unknown_key(self, agent): """Test failure when signing with an unknown key""" key = asyncssh.generate_private_key('ssh-rsa') with self.assertRaises(ValueError): yield from agent.sign(key.get_ssh_public_key(), b'test') @agent_test def test_double_close(self, agent): """Test calling close more than once on the agent""" self.assertIsNotNone(agent) agent.close() @asynctest def test_errors(self): """Test getting error responses from SSH agent""" # pylint: disable=bad-whitespace key = asyncssh.generate_private_key('ssh-rsa') keypair = asyncssh.load_keypairs(key)[0] for response in (b'', String(b''), String(Byte(SSH_AGENT_FAILURE)), String(b'\xff')): mock_agent = _Agent(response) yield from mock_agent.start('mock_agent') agent = yield from asyncssh.connect_agent('mock_agent') for request in (agent.get_keys(), agent.sign(b'xxx', b'test'), agent.add_keys([key]), agent.add_smartcard_keys('xxx'), agent.remove_keys([keypair]), 
agent.remove_smartcard_keys('xxx'), agent.remove_all(), agent.lock('passphrase'), agent.unlock('passphrase')): with self.assertRaises(ValueError): yield from request agent.close() yield from mock_agent.stop() asyncssh-1.11.1/tests/test_asn1.py000066400000000000000000000201711320320510200170300ustar00rootroot00000000000000# Copyright (c) 2015 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for ASN.1 encoding and decoding""" import codecs import unittest from asyncssh.asn1 import der_encode, der_decode from asyncssh.asn1 import ASN1EncodeError, ASN1DecodeError from asyncssh.asn1 import BitString, IA5String, ObjectIdentifier from asyncssh.asn1 import RawDERObject, TaggedDERObject, PRIVATE class _TestASN1(unittest.TestCase): """Unit tests for ASN.1 module""" # pylint: disable=bad-whitespace tests = [ (None, '0500'), (False, '010100'), (True, '0101ff'), (0, '020100'), (127, '02017f'), (128, '02020080'), (256, '02020100'), (-128, '020180'), (-129, '0202ff7f'), (-256, '0202ff00'), (b'', '0400'), (b'\0', '040100'), (b'abc', '0403616263'), (127*b'\0', '047f' + 127*'00'), (128*b'\0', '048180' + 128*'00'), ('', '0c00'), ('\0', '0c0100'), ('abc', '0c03616263'), ((), '3000'), ((1,), '3003020101'), ((1, 2), '3006020101020102'), (frozenset(), '3100'), (frozenset({1}), '3103020101'), (frozenset({1, 2}), '3106020101020102'), (frozenset({-128, 127}), '310602017f020180'), (BitString(b''), '030100'), (BitString(b'\0', 7), '03020700'), (BitString(b'\x80', 7), '03020780'), (BitString(b'\x80', named=True), '03020780'), (BitString(b'\x81', named=True), '03020081'), (BitString(b'\x81\x00', named=True), '03020081'), (BitString(b'\x80', 6), '03020680'), (BitString(b'\x80'), '03020080'), (BitString(b'\x80\x00', 7), '0303078000'), (BitString(''), '030100'), (BitString('0'), '03020700'), (BitString('1'), '03020780'), (BitString('10'), '03020680'), (BitString('10000000'), '03020080'), (BitString('10000001'), '03020081'), (BitString('100000000'), '0303078000'), (IA5String(''), '1600'), (IA5String('\0'), '160100'), (IA5String('abc'), '1603616263'), (ObjectIdentifier('0.0'), '060100'), (ObjectIdentifier('1.2'), '06012a'), (ObjectIdentifier('1.2.840'), '06032a8648'), (ObjectIdentifier('2.5'), '060155'), (ObjectIdentifier('2.40'), '060178'), (TaggedDERObject(0, None), 'a0020500'), (TaggedDERObject(1, None), 'a1020500'), (TaggedDERObject(32, None), 'bf20020500'), (TaggedDERObject(128, None), 'bf8100020500'), (TaggedDERObject(0, None, PRIVATE), 'e0020500'), (RawDERObject(0, b'', PRIVATE), 'c000') ] encode_errors = [ (range, [1]), # Unsupported type (BitString, [b'', 1]), # Bit count with empty value (BitString, [b'', -1]), # Invalid unused bit count (BitString, [b'', 8]), # Invalid unused bit count (BitString, [b'0c0', 7]), # Unused bits not zero (BitString, ['', 1]), # Unused bits with string (BitString, [0]), # Invalid type (ObjectIdentifier, ['']), # Too few components (ObjectIdentifier, ['1']), # Too few components (ObjectIdentifier, ['-1.1']), # First component out of range (ObjectIdentifier, ['3.1']), # First component out of range (ObjectIdentifier, ['0.-1']), # Second component out of range (ObjectIdentifier, ['0.40']), # Second component out of range (ObjectIdentifier, ['1.-1']), # Second component out of 
range (ObjectIdentifier, ['1.40']), # Second component out of range (ObjectIdentifier, ['1.1.-1']), # Later component out of range (TaggedDERObject, [0, None, 99]), # Invalid ASN.1 class (RawDERObject, [0, None, 99]), # Invalid ASN.1 class ] decode_errors = [ '', # Incomplete data '01', # Incomplete data '0101', # Incomplete data '1f00', # Incomplete data '1f8000', # Incomplete data '1f0001', # Incomplete data '1f80', # Incomplete tag '0180', # Indefinite length '050001', # Unexpected bytes at end '2500', # Constructed null '050100', # Null with content '2100', # Constructed boolean '010102', # Boolean value not 0x00/0xff '2200', # Constructed integer '2400', # Constructed octet string '2c00', # Constructed UTF-8 string '1000', # Non-constructed sequence '1100', # Non-constructed set '2300', # Constructed bit string '03020800', # Invalid unused bit count '3600', # Constructed IA5 string '2600', # Constructed object identifier '0600', # Empty object identifier '06020080', # Invalid component '06020081' # Incomplete component ] # pylint: enable=bad-whitespace def test_asn1(self): """Unit test ASN.1 module""" for value, data in self.tests: data = codecs.decode(data, 'hex') with self.subTest(msg='encode', value=value): self.assertEqual(der_encode(value), data) with self.subTest(msg='decode', data=data): decoded_value = der_decode(data) self.assertEqual(decoded_value, value) self.assertEqual(hash(decoded_value), hash(value)) self.assertEqual(repr(decoded_value), repr(value)) self.assertEqual(str(decoded_value), str(value)) for cls, args in self.encode_errors: with self.subTest(msg='encode error', cls=cls.__name__, args=args): with self.assertRaises(ASN1EncodeError): der_encode(cls(*args)) for data in self.decode_errors: with self.subTest(msg='decode error', data=data): with self.assertRaises(ASN1DecodeError): der_decode(codecs.decode(data, 'hex')) asyncssh-1.11.1/tests/test_auth.py000066400000000000000000000565641320320510200171460ustar00rootroot00000000000000# Copyright (c) 2015-2017 by Ron Frederick . # All rights reserved. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for authentication""" import asyncio import unittest import asyncssh from asyncssh.auth import MSG_USERAUTH_PK_OK, lookup_client_auth from asyncssh.auth import get_server_auth_methods, lookup_server_auth from asyncssh.auth import MSG_USERAUTH_GSSAPI_RESPONSE from asyncssh.constants import MSG_USERAUTH_REQUEST, MSG_USERAUTH_FAILURE from asyncssh.constants import MSG_USERAUTH_SUCCESS from asyncssh.gss import GSSClient, GSSServer from asyncssh.misc import DisconnectError, PasswordChangeRequired from asyncssh.packet import SSHPacket, Boolean, Byte, NameList, String from asyncssh.public_key import SSHLocalKeyPair from .util import asynctest, gss_available, patch_gss from .util import AsyncTestCase, ConnectionStub class _AuthConnectionStub(ConnectionStub): """Connection stub class to test authentication""" def connection_lost(self, exc): """Handle the closing of a connection""" raise NotImplementedError def process_packet(self, data): """Process an incoming packet""" raise NotImplementedError def _get_userauth_request_packet(self, method, args): """Get packet data for a user authentication request""" # pylint: disable=no-self-use return b''.join((Byte(MSG_USERAUTH_REQUEST), String('user'), String('service'), String(method)) + args) def get_userauth_request_data(self, method, *args): """Get signature data for a user authentication request""" return String('') + self._get_userauth_request_packet(method, args) class _AuthClientStub(_AuthConnectionStub): """Stub class for client connection""" @classmethod def make_pair(cls, method, **kwargs): """Make a client and server connection pair to test authentication""" client_conn = cls(method, **kwargs) return client_conn, client_conn.get_peer() def __init__(self, method, gss_host=None, override_gss_mech=False, public_key_auth=False, client_key=None, client_cert=None, override_pk_ok=False, password_auth=False, password=None, password_change=NotImplemented, password_change_prompt=None, kbdint_auth=False, kbdint_submethods=None, kbdint_challenge=None, kbdint_response=None, success=False): super().__init__(_AuthServerStub(self, gss_host, override_gss_mech, public_key_auth, override_pk_ok, password_auth, password_change_prompt, kbdint_auth, kbdint_challenge, success), False) self._gss = GSSClient(gss_host, False) if gss_host else None self._client_key = client_key self._client_cert = client_cert self._password = password self._password_change = password_change self._password_changed = None self._kbdint_submethods = kbdint_submethods self._kbdint_response = kbdint_response self._auth_waiter = asyncio.Future() self._auth = lookup_client_auth(self, method) if self._auth is None: self.close() raise ValueError('Invalid auth method') def connection_lost(self, exc=None): """Handle the closing of a connection""" if exc: self._auth_waiter.set_exception(exc) self.close() def process_packet(self, data): """Process an incoming packet""" packet = SSHPacket(data) pkttype = packet.get_byte() if pkttype == MSG_USERAUTH_FAILURE: _ = packet.get_namelist() partial_success = packet.get_boolean() packet.check_end() if partial_success: # pragma: no cover # Partial success not implemented yet self._auth.auth_succeeded() else: self._auth.auth_failed() 
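            # Auth attempt failed: report (False, password_changed) to the
            # test coroutine waiting in get_auth_result() and clear the
            # current auth state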
self._auth_waiter.set_result((False, self._password_changed)) self._auth = None self._auth_waiter = None elif pkttype == MSG_USERAUTH_SUCCESS: packet.check_end() self._auth.auth_succeeded() self._auth_waiter.set_result((True, self._password_changed)) self._auth = None self._auth_waiter = None else: self._auth.process_packet(pkttype, packet) def get_auth_result(self): """Return the result of the authentication""" return (yield from self._auth_waiter) def try_next_auth(self): """Handle a request to move to another form of auth""" # Report that the current auth attempt failed self._auth_waiter.set_result((False, self._password_changed)) self._auth = None self._auth_waiter = None @asyncio.coroutine def send_userauth_request(self, method, *args, key=None): """Send a user authentication request""" packet = self._get_userauth_request_packet(method, args) if key: packet += String(key.sign(String('') + packet)) self.send_packet(packet) def get_gss_context(self): """Return the GSS context associated with this connection""" return self._gss def gss_mic_auth_requested(self): """Return whether to allow GSS MIC authentication or not""" return bool(self._gss) @asyncio.coroutine def public_key_auth_requested(self): """Return key to use for public key authentication""" if self._client_key: return SSHLocalKeyPair(self._client_key, self._client_cert) else: return None @asyncio.coroutine def password_auth_requested(self): """Return password to send for password authentication""" # pylint: disable=no-self-use return self._password @asyncio.coroutine def password_change_requested(self, prompt, lang): """Return old & new passwords for password change""" # pylint: disable=unused-argument if self._password_change is True: return 'password', 'new_password' else: return self._password_change def password_changed(self): """Handle a successful password change""" self._password_changed = True def password_change_failed(self): """Handle an unsuccessful password change""" self._password_changed = False @asyncio.coroutine def kbdint_auth_requested(self): """Return submethods to send for keyboard-interactive authentication""" return self._kbdint_submethods @asyncio.coroutine def kbdint_challenge_received(self, name, instruction, lang, prompts): """Return responses to keyboard-interactive challenge""" # pylint: disable=no-self-use,unused-argument if self._kbdint_response is True: return ('password',) else: return self._kbdint_response class _AuthServerStub(_AuthConnectionStub): """Stub class for server connection""" def __init__(self, peer=None, gss_host=None, override_gss_mech=False, public_key_auth=False, override_pk_ok=False, password_auth=False, password_change_prompt=None, kbdint_auth=False, kbdint_challenge=False, success=False): super().__init__(peer, True) self._gss = GSSServer(gss_host) if gss_host else None self._override_gss_mech = override_gss_mech self._public_key_auth = public_key_auth self._override_pk_ok = override_pk_ok self._password_auth = password_auth self._password_change_prompt = password_change_prompt self._kbdint_auth = kbdint_auth self._kbdint_challenge = kbdint_challenge self._success = success self._auth = None def connection_lost(self, exc=None): """Handle the closing of a connection""" if self._peer: self._peer.connection_lost(exc) self.close() def process_packet(self, data): """Process an incoming packet""" packet = SSHPacket(data) pkttype = packet.get_byte() if pkttype == MSG_USERAUTH_REQUEST: _ = packet.get_string() # username _ = packet.get_string() # service method = packet.get_string() if 
self._auth: self._auth.cancel() if self._override_gss_mech: self.send_packet(Byte(MSG_USERAUTH_GSSAPI_RESPONSE), String('mismatch')) elif self._override_pk_ok: self.send_packet(Byte(MSG_USERAUTH_PK_OK), String(''), String('')) else: self._auth = lookup_server_auth(self, 'user', method, packet) else: self._auth.process_packet(pkttype, packet) def send_userauth_failure(self, partial_success): """Send a user authentication failure response""" self._auth = None self.send_packet(Byte(MSG_USERAUTH_FAILURE), NameList([]), Boolean(partial_success)) def send_userauth_success(self): """Send a user authentication success response""" self._auth = None self.send_packet(Byte(MSG_USERAUTH_SUCCESS)) def get_gss_context(self): """Return the GSS context associated with this connection""" return self._gss def gss_kex_auth_supported(self): """Return whether or not GSS key exchange authentication is supported""" # pylint: disable=no-self-use return bool(self._gss) def gss_mic_auth_supported(self): """Return whether or not GSS MIC authentication is supported""" return bool(self._gss) @asyncio.coroutine def validate_gss_principal(self, username, user_principal, host_principal): """Validate the GSS principal name for the specified user""" # pylint: disable=unused-argument return self._success def public_key_auth_supported(self): """Return whether or not public key authentication is supported""" return self._public_key_auth @asyncio.coroutine def validate_public_key(self, username, key_data, msg, signature): """Validate public key""" # pylint: disable=unused-argument return self._success def password_auth_supported(self): """Return whether or not password authentication is supported""" return self._password_auth @asyncio.coroutine def validate_password(self, username, password): """Validate password""" # pylint: disable=unused-argument if self._password_change_prompt: raise PasswordChangeRequired(self._password_change_prompt) else: return self._success @asyncio.coroutine def change_password(self, username, old_password, new_password): """Validate password""" # pylint: disable=unused-argument return self._success def kbdint_auth_supported(self): """Return whether or not keyboard-interactive authentication is supported""" return self._kbdint_auth @asyncio.coroutine def get_kbdint_challenge(self, username, lang, submethods): """Return a keyboard-interactive challenge""" # pylint: disable=unused-argument if self._kbdint_challenge is True: return '', '', '', (('Password:', False),) else: return self._kbdint_challenge @asyncio.coroutine def validate_kbdint_response(self, username, responses): """Validate keyboard-interactive responses""" # pylint: disable=unused-argument return self._success @patch_gss class _TestAuth(AsyncTestCase): """Unit tests for auth module""" @asyncio.coroutine def check_auth(self, method, expected_result, **kwargs): """Unit test authentication""" client_conn, server_conn = _AuthClientStub.make_pair(method, **kwargs) try: self.assertEqual((yield from client_conn.get_auth_result()), expected_result) finally: client_conn.close() server_conn.close() @asynctest def test_client_auth_methods(self): """Test client auth methods""" with self.subTest('Unknown client auth method'): with self.assertRaises(ValueError): _AuthClientStub.make_pair(b'xxx') @asynctest def test_server_auth_methods(self): """Test server auth methods""" with self.subTest('No auth methods'): server_conn = _AuthServerStub() self.assertEqual(get_server_auth_methods(server_conn), []) server_conn.close() with self.subTest('All auth 
methods'): gss_host = '1' if gss_available else None server_conn = _AuthServerStub(gss_host=gss_host, public_key_auth=True, password_auth=True, kbdint_auth=True) if gss_available: # pragma: no branch self.assertEqual(get_server_auth_methods(server_conn), [b'gssapi-keyex', b'gssapi-with-mic', b'publickey', b'keyboard-interactive', b'password']) else: # pragma: no cover self.assertEqual(get_server_auth_methods(server_conn), [b'publickey', b'keyboard-interactive', b'password']) server_conn.close() with self.subTest('Unknown auth method'): server_conn = _AuthServerStub() self.assertEqual(lookup_server_auth(server_conn, 'user', b'xxx', SSHPacket(b'')), None) server_conn.close() @asynctest def test_null_auth(self): """Unit test null authentication""" yield from self.check_auth(b'none', (False, None)) @unittest.skipUnless(gss_available, 'GSS not available') @asynctest def test_gss_auth(self): """Unit test GSS authentication""" with self.subTest('GSS with MIC auth not available'): yield from self.check_auth(b'gssapi-with-mic', (False, None)) for steps in range(4): with self.subTest('GSS with MIC auth available'): yield from self.check_auth(b'gssapi-with-mic', (True, None), gss_host=str(steps), success=True) gss_host = str(steps) + ',step_error' with self.subTest('GSS with MIC error', steps=steps): yield from self.check_auth(b'gssapi-with-mic', (False, None), gss_host=gss_host) with self.subTest('GSS with MIC error with token', steps=steps): yield from self.check_auth(b'gssapi-with-mic', (False, None), gss_host=gss_host + ',errtok') with self.subTest('GSS with MIC without integrity'): yield from self.check_auth(b'gssapi-with-mic', (True, None), gss_host='1,no_client_integrity,' + 'no_server_integrity', success=True) with self.subTest('GSS client integrity mismatch'): yield from self.check_auth(b'gssapi-with-mic', (False, None), gss_host='1,no_client_integrity') with self.subTest('GSS server integrity mismatch'): yield from self.check_auth(b'gssapi-with-mic', (False, None), gss_host='1,no_server_integrity') with self.subTest('GSS mechanism unknown'): yield from self.check_auth(b'gssapi-with-mic', (False, None), gss_host='1,unknown_mech') with self.subTest('GSS mechanism mismatch'): with self.assertRaises(DisconnectError): yield from self.check_auth(b'gssapi-with-mic', (False, None), gss_host='1', override_gss_mech=True) @asynctest def test_publickey_auth(self): """Unit test public key authentication""" ckey = asyncssh.generate_private_key('ssh-rsa') cert = ckey.generate_user_certificate(ckey, 'name') with self.subTest('Public key auth not available'): yield from self.check_auth(b'publickey', (False, None)) with self.subTest('Untrusted key'): yield from self.check_auth(b'publickey', (False, None), client_key=ckey, public_key_auth=True) with self.subTest('Trusted key'): yield from self.check_auth(b'publickey', (True, None), client_key=ckey, public_key_auth=True, success=True) with self.subTest('Trusted certificate'): yield from self.check_auth(b'publickey', (True, None), client_key=ckey, client_cert=cert, public_key_auth=True, success=True) with self.subTest('Invalid PK_OK message'): with self.assertRaises(DisconnectError): yield from self.check_auth(b'publickey', (False, None), client_key=ckey, public_key_auth=True, override_pk_ok=True) @asynctest def test_password_auth(self): """Unit test password authentication""" with self.subTest('Password auth not available'): yield from self.check_auth(b'password', (False, None)) with self.subTest('Invalid password'): with self.assertRaises(DisconnectError): 
yield from self.check_auth(b'password', (False, None), password_auth=True, password=b'\xff') with self.subTest('Incorrect password'): yield from self.check_auth(b'password', (False, None), password_auth=True, password='password') with self.subTest('Correct password'): yield from self.check_auth(b'password', (True, None), password_auth=True, password='password', success=True) with self.subTest('Password change not available'): yield from self.check_auth(b'password', (False, None), password_auth=True, password='password', password_change_prompt='change') with self.subTest('Invalid password change prompt'): with self.assertRaises(DisconnectError): yield from self.check_auth(b'password', (False, False), password_auth=True, password='password', password_change=True, password_change_prompt=b'\xff') with self.subTest('Password change failed'): yield from self.check_auth(b'password', (False, False), password_auth=True, password='password', password_change=True, password_change_prompt='change') with self.subTest('Password change succeeded'): yield from self.check_auth(b'password', (True, True), password_auth=True, password='password', password_change=True, password_change_prompt='change', success=True) @asynctest def test_kbdint_auth(self): """Unit test keyboard-interactive authentication""" with self.subTest('Kbdint auth not available'): yield from self.check_auth(b'keyboard-interactive', (False, None)) with self.subTest('No submethods'): yield from self.check_auth(b'keyboard-interactive', (False, None), kbdint_auth=True) with self.subTest('Invalid submethods'): with self.assertRaises(DisconnectError): yield from self.check_auth(b'keyboard-interactive', (False, None), kbdint_auth=True, kbdint_submethods=b'\xff') with self.subTest('No challenge'): yield from self.check_auth(b'keyboard-interactive', (False, None), kbdint_auth=True, kbdint_submethods='') with self.subTest('Invalid challenge name'): with self.assertRaises(DisconnectError): yield from self.check_auth(b'keyboard-interactive', (False, None), kbdint_auth=True, kbdint_submethods='', kbdint_challenge=(b'\xff', '', '', ())) with self.subTest('Invalid challenge prompt'): with self.assertRaises(DisconnectError): yield from self.check_auth(b'keyboard-interactive', (False, None), kbdint_auth=True, kbdint_submethods='', kbdint_challenge=('', '', '', ((b'\xff', False),))) with self.subTest('No response'): yield from self.check_auth(b'keyboard-interactive', (False, None), kbdint_auth=True, kbdint_submethods='', kbdint_challenge=True) with self.subTest('Invalid response'): with self.assertRaises(DisconnectError): yield from self.check_auth(b'keyboard-interactive', (False, None), kbdint_auth=True, kbdint_submethods='', kbdint_challenge=True, kbdint_response=(b'\xff',)) with self.subTest('Incorrect response'): yield from self.check_auth(b'keyboard-interactive', (False, None), kbdint_auth=True, kbdint_submethods='', kbdint_challenge=True, kbdint_response=True) with self.subTest('Correct response'): yield from self.check_auth(b'keyboard-interactive', (True, None), kbdint_auth=True, kbdint_submethods='', kbdint_challenge=True, kbdint_response=True, success=True) asyncssh-1.11.1/tests/test_auth_keys.py000066400000000000000000000206361320320510200201700ustar00rootroot00000000000000# Copyright (c) 2015 by Ron Frederick . # All rights reserved. 
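# A minimal, illustrative sketch of the authorized_keys matching flow that the
# tests in this file exercise: build an entry carrying a "from" option, import
# it with asyncssh.import_authorized_keys(), and match a client key and client
# address with validate().  The helper name _authorized_keys_example, the
# 10.0.0.0/8 network and the 10.1.2.3 client address are placeholders chosen
# here for illustration; they are not part of the test suite or a recommended
# configuration.

def _authorized_keys_example():
    """Sketch: match a client key against an authorized_keys entry"""

    import asyncssh

    key = asyncssh.generate_private_key('ssh-rsa')
    entry = ('from="10.0.0.0/8" ' +
             key.export_public_key().decode('ascii'))

    auth_keys = asyncssh.import_authorized_keys(entry)

    # validate() returns the entry's option dict on a match, or None if the
    # key or client address doesn't match any entry
    result = auth_keys.validate(key.convert_to_public(), '10.1.2.3',
                                None, False)

    return result is not None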
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for matching against authorized_keys file""" import unittest from unittest.mock import patch import asyncssh from .util import TempDirTestCase, x509_available class _TestAuthorizedKeys(TempDirTestCase): """Unit tests for auth_keys module""" keylist = [] imported_keylist = [] certlist = [] imported_certlist = [] @classmethod def setUpClass(cls): """Create public keys needed for test""" super().setUpClass() for i in range(3): key = asyncssh.generate_private_key('ssh-rsa') cls.keylist.append(key.export_public_key().decode('ascii')) cls.imported_keylist.append(key.convert_to_public()) if x509_available: # pragma: no branch subject = 'CN=cert%s' % i cert = key.generate_x509_user_certificate(key, subject) cls.certlist.append(cert.export_certificate().decode('ascii')) cls.imported_certlist.append(cert) def build_keys(self, keys, x509=False, from_file=False): """Build and import a list of authorized keys""" auth_keys = '# Comment line\n # Comment line with whitespace\n\n' for options in keys: options = options + ' ' if options else '' keynum = 1 if 'cert-authority' in options else 0 key_or_cert = (self.certlist if x509 else self.keylist)[keynum] auth_keys += '%s%s' % (options, key_or_cert) if from_file: with open('authorized_keys', 'w') as f: f.write(auth_keys) return asyncssh.read_authorized_keys('authorized_keys') else: return asyncssh.import_authorized_keys(auth_keys) def match_keys(self, tests, x509=False): """Match against authorized keys""" def getnameinfo(sockaddr, flags): """Mock reverse DNS lookup of client address""" # pylint: disable=unused-argument host, port = sockaddr if host == '127.0.0.1': return ('localhost', port) else: return sockaddr with patch('socket.getnameinfo', getnameinfo): for keys, matches in tests: auth_keys = self.build_keys(keys, x509) for (msg, keynum, client_addr, cert_principals, match) in matches: with self.subTest(msg, x509=x509): if x509: result, trusted_cert = auth_keys.validate_x509( self.imported_certlist[keynum], client_addr) if (trusted_cert and trusted_cert.subject != self.imported_certlist[keynum].subject): result = None else: result = auth_keys.validate( self.imported_keylist[keynum], client_addr, cert_principals, keynum == 1) self.assertEqual(result is not None, match) def test_matches(self): """Test authorized keys matching""" tests = ( ((None, 'cert-authority'), (('Match key or cert', 0, '1.2.3.4', None, True), ('Match CA key or cert', 1, '1.2.3.4', None, True), ('No match', 2, '1.2.3.4', None, False))), (('from="1.2.3.4"',), (('Match IP', 0, '1.2.3.4', None, True),)), (('from="1.2.3.0/24,!1.2.3.5"',), (('Match subnet', 0, '1.2.3.4', None, True), ('Exclude IP', 0, '1.2.3.5', None, False))), (('from="localhost*"',), (('Match host name', 0, '127.0.0.1', None, True),)), (('from="1.2.3.*,!1.2.3.5*"',), (('Match host pattern', 0, '1.2.3.4', None, True), ('Exclude host pattern', 0, '1.2.3.5', None, False))), (('principals="cert*,!cert1"',), (('Match principal', 0, '1.2.3.4', ['cert0'], True),)), (('cert-authority,principals="cert*,!cert1"',), (('Exclude principal', 1, '1.2.3.4', ['cert1'], False),)) ) self.match_keys(tests) if x509_available: # pragma: no branch self.match_keys(tests, x509=True) def test_options(self): """Test authorized 
keys returned option values""" tests = ( ('Command', 'command="ls abc"', {'command': 'ls abc'}), ('PermitOpen', 'permitopen="xxx:123"', {'permitopen': {('xxx', 123)}}), ('PermitOpen IPv6 address', 'permitopen="[fe80::1]:123"', {'permitopen': {('fe80::1', 123)}}), ('PermitOpen wildcard port', 'permitopen="xxx:*"', {'permitopen': {('xxx', None)}}), ('Unknown option', 'foo=abc,foo=def', {'foo': ['abc', 'def']}), ('Escaped value', 'environment="FOO=\\"xxx\\""', {'environment': {'FOO': '"xxx"'}}) ) for msg, options, expected in tests: with self.subTest(msg): auth_keys = self.build_keys([options]) result = auth_keys.validate(self.imported_keylist[0], '1.2.3.4', None, False) self.assertEqual(result, expected) def test_file(self): """Test reading authorized keys from file""" self.build_keys([None], from_file=True) @unittest.skipUnless(x509_available, 'X.509 not available') def test_subject_match(self): """Test match on X.509 subject name""" auth_keys = asyncssh.import_authorized_keys( 'x509v3-ssh-rsa subject=CN=cert0\n') result, _ = auth_keys.validate_x509( self.imported_certlist[0], '1.2.3.4') self.assertIsNotNone(result) @unittest.skipUnless(x509_available, 'X.509 not available') def test_subject_option_match(self): """Test match on X.509 subject in options""" auth_keys = asyncssh.import_authorized_keys( 'subject=CN=cert0 ' + self.certlist[0]) result, _ = auth_keys.validate_x509( self.imported_certlist[0], '1.2.3.4') self.assertIsNotNone(result) @unittest.skipUnless(x509_available, 'X.509 not available') def test_subject_option_mismatch(self): """Test failed match on X.509 subject in options""" auth_keys = asyncssh.import_authorized_keys( 'subject=CN=cert1 ' + self.certlist[0]) result, _ = auth_keys.validate_x509( self.imported_certlist[0], '1.2.3.4') self.assertIsNone(result) @unittest.skipUnless(x509_available, 'X.509 not available') def test_cert_authority_with_subject(self): """Test error when cert-authority is used with subject""" with self.assertRaises(ValueError): asyncssh.import_authorized_keys( 'cert-authority x509v3-sign-rsa subject=CN=cert0\n') @unittest.skipUnless(x509_available, 'X.509 not available') def test_non_root_ca(self): """Test error on non-root X.509 CA""" key = asyncssh.generate_private_key('ssh-rsa') cert = key.generate_x509_user_certificate(key, 'CN=a', 'CN=b') data = 'cert-authority ' + cert.export_certificate().decode('ascii') with self.assertRaises(ValueError): asyncssh.import_authorized_keys(data) def test_errors(self): """Test various authorized key parsing errors""" tests = ( ('Bad key', 'xxx\n'), ('Unbalanced quote', 'xxx"\n'), ('Unbalanced backslash', 'xxx\\\n'), ('Missing option name', '=xxx\n'), ('Environment missing equals', 'environment="FOO"\n'), ('Environment missing variable name', 'environment="=xxx"\n'), ('PermitOpen missing colon', 'permitopen="xxx"\n'), ('PermitOpen non-integer port', 'permitopen="xxx:yyy"\n') ) for msg, data in tests: with self.subTest(msg): with self.assertRaises(ValueError): asyncssh.import_authorized_keys(data) asyncssh-1.11.1/tests/test_channel.py000066400000000000000000001373401320320510200176050ustar00rootroot00000000000000# Copyright (c) 2016 by Ron Frederick . # All rights reserved. 
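# A minimal, illustrative sketch of the client-side channel API that the tests
# in this file drive: an SSHClientSession subclass collects the data delivered
# to data_received(), while create_session() opens the channel for a remote
# command.  The host 'localhost', the command 'echo hello', the helper names
# and the known_hosts=None setting below are placeholders and assumptions for
# illustration only, not recommended values.

def _channel_session_example():
    """Sketch: run a remote command and collect its output"""

    import asyncio
    import asyncssh

    class _CollectingSession(asyncssh.SSHClientSession):
        """Collect stdout/stderr data received on the channel"""

        def __init__(self):
            self.recv_buf = []

        def data_received(self, data, datatype):
            # Both stdout and stderr arrive here, tagged by datatype
            self.recv_buf.append(data)

    @asyncio.coroutine
    def _run():
        conn = yield from asyncssh.connect('localhost', known_hosts=None)

        with conn:
            chan, session = yield from conn.create_session(
                _CollectingSession, 'echo hello')
            yield from chan.wait_closed()

        yield from conn.wait_closed()
        return ''.join(session.recv_buf)

    return asyncio.get_event_loop().run_until_complete(_run())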
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for AsyncSSH channel API""" import asyncio import tempfile from unittest.mock import patch import asyncssh from asyncssh.constants import DEFAULT_LANG, MSG_USERAUTH_REQUEST from asyncssh.constants import MSG_CHANNEL_OPEN_CONFIRMATION from asyncssh.constants import MSG_CHANNEL_OPEN_FAILURE from asyncssh.constants import MSG_CHANNEL_WINDOW_ADJUST from asyncssh.constants import MSG_CHANNEL_DATA from asyncssh.constants import MSG_CHANNEL_EXTENDED_DATA from asyncssh.constants import MSG_CHANNEL_EOF, MSG_CHANNEL_CLOSE from asyncssh.constants import MSG_CHANNEL_SUCCESS from asyncssh.packet import Byte, String, UInt32 from asyncssh.public_key import CERT_TYPE_USER from asyncssh.stream import SSHTCPStreamSession, SSHUNIXStreamSession from .server import Server, ServerTestCase from .util import asynctest, echo, make_certificate PTY_OP_PARTIAL = 158 PTY_OP_NO_END = 159 class _ClientChannel(asyncssh.SSHClientChannel): """Patched SSH client channel for unit testing""" def _send_request(self, request, *args, want_reply=False): """Send a channel request""" if request == b'env' and args[1] == String('invalid'): args = args[:1] + (String(b'\xff'),) elif request == b'pty-req': if args[5][-6:-5] == Byte(PTY_OP_PARTIAL): args = args[:5] + (String(args[5][4:-5]),) elif args[5][-6:-5] == Byte(PTY_OP_NO_END): args = args[:5] + (String(args[5][4:-6]),) super()._send_request(request, *args, want_reply=want_reply) def send_request(self, request, *args): """Send a custom request (for unit testing)""" self._send_request(request, *args) @asyncio.coroutine def make_request(self, request, *args): """Make a custom request (for unit testing)""" yield from self._make_request(request, *args) class _ClientSession(asyncssh.SSHClientSession): """Unit test SSH client session""" def __init__(self): self._chan = None self.recv_buf = {None: [], asyncssh.EXTENDED_DATA_STDERR: []} self.xon_xoff = None self.exit_status = None self.exit_signal_msg = None self.exc = None def connection_made(self, chan): """Handle connection open""" self._chan = chan def connection_lost(self, exc): """Handle connection close""" self.exc = exc self._chan = None def data_received(self, data, datatype): """Handle data from the channel""" self.recv_buf[datatype].append(data) def xon_xoff_requested(self, client_can_do): """Handle request to enable/disable XON/XOFF flow control""" self.xon_xoff = client_can_do def exit_status_received(self, status): """Handle remote exit status""" # pylint: disable=unused-argument self.exit_status = status def exit_signal_received(self, signal, core_dumped, msg, lang): """Handle remote exit signal""" # pylint: disable=unused-argument self.exit_signal_msg = msg @asyncio.coroutine def _create_session(conn, command=None, *, subsystem=None, **kwargs): """Create a client session""" return (yield from conn.create_session(_ClientSession, command, subsystem=subsystem, **kwargs)) class _ServerChannel(asyncssh.SSHServerChannel): """Patched SSH server channel class for unit testing""" def _send_request(self, request, *args, want_reply=False): """Send a channel request""" if request == b'exit-signal': if args[0] == String('invalid'): args = (String(b'\xff'),) + args[1:] if args[3] == String('invalid'): args = args[:3] + 
(String(b'\xff'),) super()._send_request(request, *args, want_reply=want_reply) def _process_delayed_request(self, packet): """Process a request that delays before responding""" packet.check_end() asyncio.get_event_loop().call_later(0.1, self._report_response, True) def send_packet(self, pkttype, *args): """Send a packet for unit testing (bypassing state checks)""" self._send_packet(pkttype, *args) @asyncio.coroutine def open_session(self): """Attempt to open a session on the client""" return (yield from self._open(b'session')) class _EchoServerSession(asyncssh.SSHServerSession): """A shell session which echos data from stdin to stdout/stderr""" def __init__(self): self._chan = None self._pty_ok = True def connection_made(self, chan): """Handle session open""" self._chan = chan username = self._chan.get_extra_info('username') if username == 'close': self._chan.close() elif username == 'no_pty': self._pty_ok = False elif username == 'task_error': raise RuntimeError('Exception handler test') def pty_requested(self, term_type, term_size, term_modes): """Handle pseudo-terminal request""" return self._pty_ok def shell_requested(self): """Handle shell request""" return True def data_received(self, data, datatype): """Handle data from the channel""" self._chan.write(data[:1]) self._chan.writelines([data[1:]]) self._chan.write_stderr(data[:1]) self._chan.writelines_stderr([data[1:]]) def eof_received(self): """Handle EOF on the channel""" self._chan.write_eof() self._chan.close() class _ChannelServer(Server): """Server for testing the AsyncSSH channel API""" def _begin_session(self, stdin, stdout, stderr): """Begin processing a new session""" # pylint: disable=too-many-statements action = stdin.channel.get_command() or stdin.channel.get_subsystem() if not action: action = 'echo' if action == 'echo': yield from echo(stdin, stdout, stderr) elif action == 'conn_close': yield from stdin.read(1) stdout.write('\n') self._conn.close() elif action == 'close': yield from stdin.read(1) stdout.write('\n') elif action == 'agent': agent = yield from asyncssh.connect_agent(self._conn) if agent: stdout.write(str(len((yield from agent.get_keys()))) + '\n') agent.close() else: stdout.channel.exit(1) elif action == 'agent_sock': agent_path = stdin.channel.get_agent_path() if agent_path: agent = yield from asyncssh.connect_agent(agent_path) stdout.write(str(len((yield from agent.get_keys()))) + '\n') agent.close() else: stdout.channel.exit(1) elif action == 'rejected_agent': agent_path = stdin.channel.get_agent_path() stdout.write(str(bool(agent_path)) + '\n') chan = self._conn.create_agent_channel() try: yield from chan.open(SSHUNIXStreamSession) except asyncssh.ChannelOpenError: stdout.channel.exit(1) elif action == 'rejected_session': chan = _ServerChannel(self._conn, asyncio.get_event_loop(), False, False, 0, None, 1, 32768) try: yield from chan.open_session() except asyncssh.ChannelOpenError: stdout.channel.exit(1) elif action == 'rejected_tcpip_direct': chan = self._conn.create_tcp_channel() try: yield from chan.connect(SSHTCPStreamSession, '', 0, '', 0) except asyncssh.ChannelOpenError: stdout.channel.exit(1) elif action == 'unknown_tcpip_listener': chan = self._conn.create_tcp_channel() try: yield from chan.accept(SSHTCPStreamSession, 'xxx', 0, '', 0) except asyncssh.ChannelOpenError: stdout.channel.exit(1) elif action == 'invalid_tcpip_listener': chan = self._conn.create_tcp_channel() try: yield from chan.accept(SSHTCPStreamSession, b'\xff', 0, '', 0) except asyncssh.ChannelOpenError: stdout.channel.exit(1) 
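# The UNIX domain socket branches below mirror the TCP/IP cases above: each
# attempts a direct connection or a listener accept back to the client and
# exits with status 1 if that open attempt fails with ChannelOpenError.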
elif action == 'rejected_unix_direct': chan = self._conn.create_unix_channel() try: yield from chan.connect(SSHUNIXStreamSession, '') except asyncssh.ChannelOpenError: stdout.channel.exit(1) elif action == 'unknown_unix_listener': chan = self._conn.create_unix_channel() try: yield from chan.accept(SSHUNIXStreamSession, 'xxx') except asyncssh.ChannelOpenError: stdout.channel.exit(1) elif action == 'invalid_unix_listener': chan = self._conn.create_unix_channel() try: yield from chan.accept(SSHUNIXStreamSession, b'\xff') except asyncssh.ChannelOpenError: stdout.channel.exit(1) elif action == 'late_auth_banner': try: self._conn.send_auth_banner('auth banner') except OSError: stdin.channel.exit(1) elif action == 'invalid_open_confirm': stdin.channel.send_packet(MSG_CHANNEL_OPEN_CONFIRMATION, UInt32(0), UInt32(0), UInt32(0)) elif action == 'invalid_open_failure': stdin.channel.send_packet(MSG_CHANNEL_OPEN_FAILURE, UInt32(0), String(''), String('')) elif action == 'env': value = stdin.channel.get_environment().get('TEST', '') stdout.write(value + '\n') elif action == 'term': chan = stdin.channel info = str((chan.get_terminal_type(), chan.get_terminal_size(), chan.get_terminal_mode(asyncssh.PTY_OP_OSPEED))) stdout.write(info + '\n') elif action == 'xon_xoff': stdin.channel.set_xon_xoff(True) elif action == 'no_xon_xoff': stdin.channel.set_xon_xoff(False) elif action == 'signals': try: yield from stdin.readline() except asyncssh.BreakReceived as exc: stdin.channel.exit_with_signal('ABRT', False, str(exc.msec)) except asyncssh.SignalReceived as exc: stdin.channel.exit_with_signal('ABRT', False, exc.signal) except asyncssh.TerminalSizeChanged as exc: size = (exc.width, exc.height, exc.pixwidth, exc.pixheight) stdin.channel.exit_with_signal('ABRT', False, str(size)) elif action == 'exit_status': stdin.channel.exit(1) elif action == 'closed_status': stdin.channel.close() stdin.channel.exit(1) elif action == 'exit_signal': stdin.channel.exit_with_signal('ABRT', False, 'exit_signal') elif action == 'closed_signal': stdin.channel.close() stdin.channel.exit_with_signal('ABRT', False, 'closed_signal') elif action == 'invalid_exit_signal': stdin.channel.exit_with_signal('invalid') elif action == 'invalid_exit_lang': stdin.channel.exit_with_signal('ABRT', False, '', 'invalid') elif action == 'window_after_close': stdin.channel.send_packet(MSG_CHANNEL_CLOSE) stdin.channel.send_packet(MSG_CHANNEL_WINDOW_ADJUST, UInt32(0)) elif action == 'empty_data': stdin.channel.send_packet(MSG_CHANNEL_DATA, String('')) elif action == 'partial_unicode': data = '\xff\xff'.encode('utf-8') stdin.channel.send_packet(MSG_CHANNEL_DATA, String(data[:3])) stdin.channel.send_packet(MSG_CHANNEL_DATA, String(data[3:])) elif action == 'partial_unicode_at_eof': data = '\xff\xff'.encode('utf-8') stdin.channel.send_packet(MSG_CHANNEL_DATA, String(data[:3])) elif action == 'unicode_error': stdin.channel.send_packet(MSG_CHANNEL_DATA, String(b'\xff')) elif action == 'data_past_window': stdin.channel.send_packet(MSG_CHANNEL_DATA, String(2*1025*1024*'\0')) elif action == 'data_after_eof': stdin.channel.send_packet(MSG_CHANNEL_EOF) stdout.write('xxx') elif action == 'data_after_close': yield from asyncio.sleep(0.1) stdout.write('xxx') elif action == 'ext_data_after_eof': stdin.channel.send_packet(MSG_CHANNEL_EOF) stdin.channel.write_stderr('xxx') elif action == 'invalid_datatype': stdin.channel.send_packet(MSG_CHANNEL_EXTENDED_DATA, UInt32(255), String('')) elif action == 'double_eof': stdin.channel.send_packet(MSG_CHANNEL_EOF) 
stdin.channel.write_eof() elif action == 'double_close': yield from asyncio.sleep(0.1) stdout.write('xxx') stdin.channel.send_packet(MSG_CHANNEL_CLOSE) elif action == 'request_after_close': stdin.channel.send_packet(MSG_CHANNEL_CLOSE) stdin.channel.exit(1) elif action == 'unexpected_auth': self._conn.send_packet(Byte(MSG_USERAUTH_REQUEST), String('guest'), String('ssh-connection'), String('none')) elif action == 'invalid_response': stdin.channel.send_packet(MSG_CHANNEL_SUCCESS) else: stdin.channel.exit(255) stdin.channel.close() yield from stdin.channel.wait_closed() def begin_auth(self, username): """Handle client authentication request""" return username not in {'guest', 'conn_close', 'close', 'echo', 'no_channels', 'no_pty', 'task_error'} def session_requested(self): """Handle a request to create a new session""" username = self._conn.get_extra_info('username') with patch('asyncssh.connection.SSHServerChannel', _ServerChannel): channel = self._conn.create_server_channel() if username == 'conn_close': self._conn.close() return False elif username in {'close', 'echo', 'no_pty', 'task_error'}: return (channel, _EchoServerSession()) elif username != 'no_channels': return (channel, self._begin_session) else: return False class _TestChannel(ServerTestCase): """Unit tests for AsyncSSH channel API""" # pylint: disable=too-many-public-methods @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server for the tests to use""" return (yield from cls.create_server( _ChannelServer, authorized_client_keys='authorized_keys')) @asyncio.coroutine def _check_action(self, command, expected_result): """Run a command on a remote session and check for a specific result""" with (yield from self.connect()) as conn: chan, session = yield from _create_session(conn, command) yield from chan.wait_closed() self.assertEqual(session.exit_status, expected_result) yield from conn.wait_closed() @asyncio.coroutine def _check_session(self, conn, command=None, *, subsystem=None, large_block=False, **kwargs): """Open a session and test if an input line is echoed back""" chan, session = yield from _create_session(conn, command, subsystem=subsystem, *kwargs) if large_block: data = 4 * [1025*1024*'\0'] else: data = [str(id(self))] self.assertTrue(chan.can_write_eof()) chan.writelines(data) chan.write_eof() yield from chan.wait_closed() data = ''.join(data) for buf in session.recv_buf.values(): self.assertEqual(data, ''.join(buf)) chan.close() @asynctest def test_shell(self): """Test starting a shell""" with (yield from self.connect(username='echo')) as conn: yield from self._check_session(conn) yield from conn.wait_closed() @asynctest def test_shell_failure(self): """Test failure to start a shell""" with (yield from self.connect(username='no_channels')) as conn: with self.assertRaises(asyncssh.ChannelOpenError): yield from _create_session(conn) yield from conn.wait_closed() @asynctest def test_shell_internal_error(self): """Test internal error in callback to start a shell""" with (yield from self.connect(username='task_error')) as conn: with self.assertRaises((OSError, asyncssh.DisconnectError)): yield from _create_session(conn) yield from conn.wait_closed() @asynctest def test_shell_large_block(self): """Test starting a shell and sending a large block of data""" with (yield from self.connect(username='echo')) as conn: yield from self._check_session(conn, large_block=True) yield from conn.wait_closed() @asynctest def test_exec(self): """Test execution of a remote command""" with (yield from self.connect()) as 
conn: yield from self._check_session(conn, 'echo') yield from conn.wait_closed() @asynctest def test_forced_exec(self): """Test execution of a forced remote command""" ckey = asyncssh.read_private_key('ckey') cert = make_certificate('ssh-rsa-cert-v01@openssh.com', CERT_TYPE_USER, ckey, ckey, ['ckey'], options={'force-command': String('echo')}) with (yield from self.connect(username='ckey', client_keys=[(ckey, cert)])) as conn: yield from self._check_session(conn) yield from conn.wait_closed() @asynctest def test_invalid_exec(self): """Test execution of an invalid remote command""" with (yield from self.connect()) as conn: with self.assertRaises(asyncssh.ChannelOpenError): yield from _create_session(conn, b'\xff') yield from conn.wait_closed() @asynctest def test_exec_failure(self): """Test failure to execute a remote command""" with (yield from self.connect(username='no_channels')) as conn: with self.assertRaises(asyncssh.ChannelOpenError): yield from _create_session(conn, 'echo') yield from conn.wait_closed() @asynctest def test_subsystem(self): """Test starting a subsystem""" with (yield from self.connect()) as conn: yield from self._check_session(conn, subsystem='echo') yield from conn.wait_closed() @asynctest def test_invalid_subsystem(self): """Test starting an invalid subsystem""" with (yield from self.connect()) as conn: with self.assertRaises(asyncssh.ChannelOpenError): yield from _create_session(conn, subsystem=b'\xff') yield from conn.wait_closed() @asynctest def test_subsystem_failure(self): """Test failure to start a subsystem""" with (yield from self.connect(username='no_channels')) as conn: with self.assertRaises(asyncssh.ChannelOpenError): yield from _create_session(conn, subsystem='echo') yield from conn.wait_closed() @asynctest def test_conn_close_during_startup(self): """Test connection close during channel startup""" with (yield from self.connect(username='conn_close')) as conn: with self.assertRaises(asyncssh.ChannelOpenError): yield from _create_session(conn) yield from conn.wait_closed() @asynctest def test_close_during_startup(self): """Test channel close during startup""" with (yield from self.connect(username='close')) as conn: with self.assertRaises(asyncssh.ChannelOpenError): yield from _create_session(conn) yield from conn.wait_closed() @asynctest def test_inbound_conn_close_while_read_paused(self): """Test inbound connection close while reading is paused""" with (yield from self.connect()) as conn: chan, _ = yield from _create_session(conn, 'conn_close') chan.pause_reading() chan.write('\n') yield from asyncio.sleep(0.1) conn.close() yield from asyncio.sleep(0.1) chan.resume_reading() yield from chan.wait_closed() yield from conn.wait_closed() @asynctest def test_outbound_conn_close_while_read_paused(self): """Test outbound connection close while reading is paused""" with (yield from self.connect()) as conn: chan, _ = yield from _create_session(conn, 'close') chan.pause_reading() chan.write('\n') yield from asyncio.sleep(0.1) conn.close() yield from asyncio.sleep(0.1) chan.resume_reading() yield from chan.wait_closed() yield from conn.wait_closed() @asynctest def test_invalid_open_confirmation(self): """Test receiving an open confirmation on already open channel""" with (yield from self.connect()) as conn: chan, _ = yield from _create_session(conn, 'invalid_open_confirm') yield from chan.wait_closed() yield from conn.wait_closed() @asynctest def test_invalid_open_failure(self): """Test receiving an open failure on already open channel""" with (yield from 
self.connect()) as conn: chan, _ = yield from _create_session(conn, 'invalid_open_failure') yield from chan.wait_closed() yield from conn.wait_closed() @asynctest def test_invalid_channel_request(self): """Test sending non-ASCII channel request""" with patch('asyncssh.connection.SSHClientChannel', _ClientChannel): with (yield from self.connect()) as conn: chan, _ = yield from _create_session(conn) with self.assertRaises(asyncssh.DisconnectError): yield from chan.make_request('\xff') yield from conn.wait_closed() @asynctest def test_delayed_channel_request(self): """Test queuing channel requests with delayed response""" with patch('asyncssh.connection.SSHClientChannel', _ClientChannel): with (yield from self.connect()) as conn: chan, _ = yield from _create_session(conn) chan.send_request(b'delayed') chan.send_request(b'delayed') yield from conn.wait_closed() @asynctest def test_invalid_channel_response(self): """Test receiving response for non-existent channel request""" with (yield from self.connect()) as conn: chan, _ = yield from _create_session(conn, 'invalid_response') chan.close() yield from conn.wait_closed() @asynctest def test_already_open(self): """Test connect on an already open channel""" with (yield from self.connect()) as conn: chan, _ = yield from _create_session(conn) with self.assertRaises(OSError): yield from chan.create(None, None, None, {}, None, None, None, False, None, None, False, False) chan.close() yield from conn.wait_closed() @asynctest def test_write_buffer(self): """Test setting write buffer limits""" with (yield from self.connect()) as conn: chan, _ = yield from _create_session(conn) chan.set_write_buffer_limits() chan.set_write_buffer_limits(low=8192) chan.set_write_buffer_limits(high=32768) chan.set_write_buffer_limits(32768, 8192) with self.assertRaises(ValueError): chan.set_write_buffer_limits(8192, 32768) self.assertEqual(chan.get_write_buffer_size(), 0) chan.close() yield from conn.wait_closed() @asynctest def test_empty_write(self): """Test writing an empty block of data""" with (yield from self.connect()) as conn: chan, _ = yield from _create_session(conn) chan.write('') chan.close() yield from conn.wait_closed() @asynctest def test_invalid_write_extended(self): """Test writing using an invalid extended data type""" with (yield from self.connect()) as conn: chan, _ = yield from _create_session(conn) with self.assertRaises(OSError): chan.write('test', -1) yield from conn.wait_closed() @asynctest def test_unneeded_resume_reading(self): """Test resume reading when not paused""" with (yield from self.connect()) as conn: chan, _ = yield from _create_session(conn) chan.resume_reading() chan.close() yield from conn.wait_closed() @asynctest def test_agent_forwarding(self): """Test SSH agent forwarding""" if not self.agent_available(): # pragma: no cover self.skipTest('ssh-agent not available') with (yield from self.connect(username='ckey', agent_forwarding=True)) as conn: chan, session = yield from _create_session(conn, 'agent') yield from chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, '3\n') chan, session = yield from _create_session(conn, 'agent') yield from chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, '3\n') yield from conn.wait_closed() @asynctest def test_agent_forwarding_sock(self): """Test SSH agent forwarding via UNIX domain socket""" if not self.agent_available(): # pragma: no cover self.skipTest('ssh-agent not available') with (yield from self.connect(username='ckey', 
agent_forwarding=True)) as conn: chan, session = yield from _create_session(conn, 'agent_sock') yield from chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, '3\n') yield from conn.wait_closed() @asynctest def test_rejected_session(self): """Test receiving inbound session request""" yield from self._check_action('rejected_session', 1) @asynctest def test_rejected_tcpip_direct(self): """Test receiving inbound direct TCP/IP connection""" yield from self._check_action('rejected_tcpip_direct', 1) @asynctest def test_unknown_tcpip_listener(self): """Test receiving connection on unknown TCP/IP listener""" yield from self._check_action('unknown_tcpip_listener', 1) @asynctest def test_invalid_tcpip_listener(self): """Test receiving connection on invalid TCP/IP listener path""" yield from self._check_action('invalid_tcpip_listener', None) @asynctest def test_rejected_unix_direct(self): """Test receiving inbound direct UNIX connection""" yield from self._check_action('rejected_unix_direct', 1) @asynctest def test_unknown_unix_listener(self): """Test receiving connection on unknown UNIX listener""" yield from self._check_action('unknown_unix_listener', 1) @asynctest def test_invalid_unix_listener(self): """Test receiving connection on invalid UNIX listener path""" yield from self._check_action('invalid_unix_listener', None) @asynctest def test_agent_forwarding_failure(self): """Test failure of SSH agent forwarding""" ckey = asyncssh.read_private_key('ckey') cert = make_certificate('ssh-rsa-cert-v01@openssh.com', CERT_TYPE_USER, ckey, ckey, ['ckey'], extensions={'no-agent-forwarding': ''}) with (yield from self.connect(username='ckey', client_keys=[(ckey, cert)], agent_forwarding=True)) as conn: chan, session = yield from _create_session(conn, 'agent') yield from chan.wait_closed() self.assertEqual(session.exit_status, 1) yield from conn.wait_closed() @asynctest def test_agent_forwarding_sock_failure(self): """Test failure to create SSH agent forwarding socket""" tempfile.tempdir = 'xxx' with (yield from self.connect(username='ckey', agent_forwarding=True)) as conn: chan, session = yield from _create_session(conn, 'agent_sock') yield from chan.wait_closed() self.assertEqual(session.exit_status, 1) yield from conn.wait_closed() tempfile.tempdir = None @asynctest def test_agent_forwarding_not_offered(self): """Test SSH agent forwarding not offered by client""" with (yield from self.connect()) as conn: chan, session = yield from _create_session(conn, 'agent') yield from chan.wait_closed() self.assertEqual(session.exit_status, 1) yield from conn.wait_closed() @asynctest def test_agent_forwarding_rejected(self): """Test rejection of SSH agent forwarding by client""" with (yield from self.connect()) as conn: chan, session = yield from _create_session(conn, 'rejected_agent') yield from chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, 'False\n') self.assertEqual(session.exit_status, 1) yield from conn.wait_closed() @asynctest def test_terminal_info(self): """Test sending terminal information""" modes = {asyncssh.PTY_OP_OSPEED: 9600} with (yield from self.connect()) as conn: chan, session = yield from _create_session(conn, 'term', term_type='ansi', term_size=(80, 24), term_modes=modes) yield from chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, "('ansi', (80, 24, 0, 0), 9600)\r\n") yield from conn.wait_closed() @asynctest def test_terminal_full_size(self): """Test sending terminal information with full 
size""" modes = {asyncssh.PTY_OP_OSPEED: 9600} with (yield from self.connect()) as conn: chan, session = yield from _create_session(conn, 'term', term_type='ansi', term_size=(80, 24, 480, 240), term_modes=modes) yield from chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, "('ansi', (80, 24, 480, 240), 9600)\r\n") yield from conn.wait_closed() @asynctest def test_invalid_terminal_size(self): """Test sending invalid terminal size""" with (yield from self.connect()) as conn: with self.assertRaises(ValueError): yield from _create_session(conn, 'term', term_type='ansi', term_size=(0, 0, 0)) yield from conn.wait_closed() @asynctest def test_invalid_terminal_modes(self): """Test sending invalid terminal modes""" modes = {asyncssh.PTY_OP_RESERVED: 0} with (yield from self.connect()) as conn: with self.assertRaises(ValueError): yield from _create_session(conn, 'term', term_type='ansi', term_modes=modes) yield from conn.wait_closed() @asynctest def test_pty_disallowed_by_cert(self): """Test rejection of pty request by certificate""" ckey = asyncssh.read_private_key('ckey') cert = make_certificate('ssh-rsa-cert-v01@openssh.com', CERT_TYPE_USER, ckey, ckey, ['ckey'], extensions={'no-pty': ''}) with (yield from self.connect(username='ckey', client_keys=[(ckey, cert)])) as conn: with self.assertRaises(asyncssh.ChannelOpenError): yield from _create_session(conn, 'term', term_type='ansi') yield from conn.wait_closed() @asynctest def test_pty_disallowed_by_session(self): """Test rejection of pty request by session""" with (yield from self.connect(username='no_pty')) as conn: with self.assertRaises(asyncssh.ChannelOpenError): yield from _create_session(conn, 'term', term_type='ansi') yield from conn.wait_closed() @asynctest def test_invalid_term_type(self): """Test requesting an invalid terminal type""" with patch('asyncssh.connection.SSHClientChannel', _ClientChannel): with (yield from self.connect()) as conn: with self.assertRaises(asyncssh.DisconnectError): yield from _create_session(conn, term_type=b'\xff') yield from conn.wait_closed() @asynctest def test_term_modes_missing_end(self): """Test sending terminal modes without PTY_OP_END""" modes = {asyncssh.PTY_OP_OSPEED: 9600, PTY_OP_NO_END: 0} with patch('asyncssh.connection.SSHClientChannel', _ClientChannel): with (yield from self.connect()) as conn: chan, session = yield from _create_session(conn, 'term', term_type='ansi', term_modes=modes) yield from chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, "('ansi', (0, 0, 0, 0), 9600)\r\n") yield from conn.wait_closed() @asynctest def test_term_modes_incomplete(self): """Test sending terminal modes with incomplete value""" modes = {asyncssh.PTY_OP_OSPEED: 9600, PTY_OP_PARTIAL: 0} with patch('asyncssh.connection.SSHClientChannel', _ClientChannel): with (yield from self.connect()) as conn: with self.assertRaises(asyncssh.DisconnectError): yield from _create_session(conn, 'term', term_type='ansi', term_modes=modes) yield from conn.wait_closed() @asynctest def test_env(self): """Test sending environment""" with (yield from self.connect()) as conn: chan, session = yield from _create_session(conn, 'env', env={'TEST': 'test'}) yield from chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, 'test\n') yield from conn.wait_closed() @asynctest def test_invalid_env(self): """Test sending invalid environment""" with patch('asyncssh.connection.SSHClientChannel', _ClientChannel): with (yield from self.connect()) as conn: 
chan, session = yield from _create_session( conn, 'env', env={'TEST': 'invalid'}) yield from chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, '\n') yield from conn.wait_closed() @asynctest def test_xon_xoff_enable(self): """Test enabling XON/XOFF flow control""" with (yield from self.connect()) as conn: chan, session = yield from _create_session(conn, 'xon_xoff') yield from chan.wait_closed() self.assertEqual(session.xon_xoff, True) yield from conn.wait_closed() @asynctest def test_xon_xoff_disable(self): """Test disabling XON/XOFF flow control""" with (yield from self.connect()) as conn: chan, session = yield from _create_session(conn, 'no_xon_xoff') yield from chan.wait_closed() self.assertEqual(session.xon_xoff, False) yield from conn.wait_closed() @asynctest def test_break(self): """Test sending a break""" with (yield from self.connect()) as conn: chan, session = yield from _create_session(conn, 'signals') chan.send_break(1000) yield from chan.wait_closed() self.assertEqual(session.exit_signal_msg, '1000') yield from conn.wait_closed() @asynctest def test_signal(self): """Test sending a signal""" with (yield from self.connect()) as conn: chan, session = yield from _create_session(conn, 'signals') chan.send_signal('HUP') yield from chan.wait_closed() self.assertEqual(session.exit_signal_msg, 'HUP') yield from conn.wait_closed() @asynctest def test_terminate(self): """Test sending a terminate signal""" with (yield from self.connect()) as conn: chan, session = yield from _create_session(conn, 'signals') chan.terminate() yield from chan.wait_closed() self.assertEqual(session.exit_signal_msg, 'TERM') yield from conn.wait_closed() @asynctest def test_kill(self): """Test sending a kill signal""" with (yield from self.connect()) as conn: chan, session = yield from _create_session(conn, 'signals') chan.kill() yield from chan.wait_closed() self.assertEqual(session.exit_signal_msg, 'KILL') yield from conn.wait_closed() @asynctest def test_invalid_signal(self): """Test sending an invalid signal""" with patch('asyncssh.connection.SSHClientChannel', _ClientChannel): with (yield from self.connect()) as conn: chan, session = yield from _create_session(conn, 'signals') chan.send_signal(b'\xff') chan.write('\n') yield from chan.wait_closed() self.assertEqual(session.exit_status, None) yield from conn.wait_closed() @asynctest def test_terminal_size_change(self): """Test sending terminal size change""" with (yield from self.connect()) as conn: chan, session = yield from _create_session(conn, 'signals', term_type='ansi') chan.change_terminal_size(80, 24) yield from chan.wait_closed() self.assertEqual(session.exit_signal_msg, '(80, 24, 0, 0)') yield from conn.wait_closed() @asynctest def test_exit_status(self): """Test receiving exit status""" with (yield from self.connect()) as conn: chan, session = yield from _create_session(conn, 'exit_status') yield from chan.wait_closed() self.assertEqual(session.exit_status, 1) self.assertEqual(chan.get_exit_status(), 1) self.assertIsNone(chan.get_exit_signal()) yield from conn.wait_closed() @asynctest def test_exit_status_after_close(self): """Test delivery of exit status after remote close""" with (yield from self.connect()) as conn: chan, session = yield from _create_session(conn, 'closed_status') yield from chan.wait_closed() self.assertIsNone(session.exit_status) self.assertIsNone(chan.get_exit_status()) self.assertIsNone(chan.get_exit_signal()) yield from conn.wait_closed() @asynctest def test_exit_signal(self): """Test 
receiving exit signal""" with (yield from self.connect()) as conn: chan, session = yield from _create_session(conn, 'exit_signal') yield from chan.wait_closed() self.assertEqual(session.exit_signal_msg, 'exit_signal') self.assertEqual(chan.get_exit_status(), -1) self.assertEqual(chan.get_exit_signal(), ('ABRT', False, 'exit_signal', DEFAULT_LANG)) yield from conn.wait_closed() @asynctest def test_exit_signal_after_close(self): """Test delivery of exit signal after remote close""" with (yield from self.connect()) as conn: chan, session = yield from _create_session(conn, 'closed_signal') yield from chan.wait_closed() self.assertIsNone(session.exit_signal_msg) self.assertIsNone(chan.get_exit_status()) self.assertIsNone(chan.get_exit_signal()) yield from conn.wait_closed() @asynctest def test_invalid_exit_signal(self): """Test delivery of invalid exit signal""" with (yield from self.connect()) as conn: chan, _ = yield from _create_session(conn, 'invalid_exit_signal') yield from chan.wait_closed() yield from conn.wait_closed() @asynctest def test_invalid_exit_lang(self): """Test delivery of invalid exit signal language""" with (yield from self.connect()) as conn: chan, _ = yield from _create_session(conn, 'invalid_exit_lang') yield from chan.wait_closed() yield from conn.wait_closed() @asynctest def test_window_adjust_after_eof(self): """Test receiving window adjust after EOF""" with (yield from self.connect()) as conn: chan, _ = yield from _create_session(conn, 'window_after_close') yield from chan.wait_closed() yield from conn.wait_closed() @asynctest def test_empty_data(self): """Test receiving empty data packet""" with (yield from self.connect()) as conn: chan, _ = yield from _create_session(conn, 'empty_data') chan.close() yield from conn.wait_closed() @asynctest def test_partial_unicode(self): """Test receiving Unicode data spread across two packets""" with (yield from self.connect()) as conn: chan, session = yield from _create_session(conn, 'partial_unicode') yield from chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, '\xff\xff') yield from conn.wait_closed() @asynctest def test_partial_unicode_at_eof(self): """Test receiving partial Unicode data and then EOF""" with (yield from self.connect()) as conn: chan, session = yield from _create_session( conn, 'partial_unicode_at_eof') yield from chan.wait_closed() self.assertIsInstance(session.exc, asyncssh.DisconnectError) yield from conn.wait_closed() @asynctest def test_unicode_error(self): """Test receiving bad Unicode data""" with (yield from self.connect()) as conn: with self.assertRaises(asyncssh.DisconnectError): yield from _create_session(conn, 'unicode_error') yield from conn.wait_closed() @asynctest def test_data_past_window(self): """Test receiving a data packet past the advertised window""" with (yield from self.connect()) as conn: chan, _ = yield from _create_session(conn, 'data_past_window') yield from chan.wait_closed() yield from conn.wait_closed() @asynctest def test_data_after_eof(self): """Test receiving data after EOF""" with (yield from self.connect()) as conn: chan, _ = yield from _create_session(conn, 'data_after_eof') yield from chan.wait_closed() yield from conn.wait_closed() @asynctest def test_data_after_close(self): """Test receiving data after close""" with (yield from self.connect()) as conn: chan, _ = yield from _create_session(conn, 'data_after_close') chan.write(4*1025*1024*'\0') chan.close() yield from asyncio.sleep(0.2) yield from chan.wait_closed() yield from 
conn.wait_closed() @asynctest def test_extended_data_after_eof(self): """Test receiving extended data after EOF""" with (yield from self.connect()) as conn: chan, _ = yield from _create_session(conn, 'ext_data_after_eof') yield from chan.wait_closed() yield from conn.wait_closed() @asynctest def test_invalid_datatype(self): """Test receiving data with invalid data type""" with (yield from self.connect()) as conn: chan, _ = yield from _create_session(conn, 'invalid_datatype') yield from chan.wait_closed() yield from conn.wait_closed() @asynctest def test_double_eof(self): """Test receiving two EOF messages""" with (yield from self.connect()) as conn: chan, _ = yield from _create_session(conn, 'double_eof') yield from chan.wait_closed() yield from conn.wait_closed() @asynctest def test_double_close(self): """Test receiving two close messages""" with (yield from self.connect()) as conn: chan, _ = yield from _create_session(conn, 'double_close') chan.pause_reading() yield from asyncio.sleep(0.2) chan.resume_reading() yield from chan.wait_closed() yield from conn.wait_closed() @asynctest def test_request_after_close(self): """Test receiving a channel request after a close""" with (yield from self.connect()) as conn: chan, _ = yield from _create_session(conn, 'request_after_close') yield from chan.wait_closed() yield from conn.wait_closed() @asynctest def test_late_auth_banner(self): """Test server sending authentication banner after auth completes""" with (yield from self.connect()) as conn: chan, session = yield from _create_session(conn, 'late_auth_banner') yield from chan.wait_closed() self.assertEqual(session.exit_status, 1) yield from conn.wait_closed() @asynctest def test_unexpected_userauth_request(self): """Test userauth request sent to client""" with (yield from self.connect()) as conn: chan, _ = yield from _create_session(conn, 'unexpected_auth') yield from chan.wait_closed() yield from conn.wait_closed() @asynctest def test_unknown_action(self): """Test unknown action""" with (yield from self.connect()) as conn: chan, session = yield from _create_session(conn, 'unknown') yield from chan.wait_closed() self.assertEqual(session.exit_status, 255) yield from conn.wait_closed() asyncssh-1.11.1/tests/test_cipher.py000066400000000000000000000076601320320510200174500ustar00rootroot00000000000000# Copyright (c) 2015 by Ron Frederick . # All rights reserved. 
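# A minimal, illustrative sketch of the round-trip pattern that the test below
# applies to every registered encryption algorithm, using the same internal
# asyncssh.cipher helpers.  The algorithm name aes256-ctr and the helper name
# are chosen here for illustration, and these helpers are internal APIs that
# may change between releases.

def _cipher_roundtrip_example(alg=b'aes256-ctr'):
    """Sketch: encrypt and decrypt a block with a basic (non-AEAD) cipher"""

    import os

    from asyncssh.cipher import get_cipher, get_encryption_params

    keysize, ivsize, blocksize, _mode = get_encryption_params(alg)

    key = os.urandom(keysize)
    iv = os.urandom(ivsize)
    data = os.urandom(4 * blocksize)

    # Separate cipher objects are used for each direction, as in the test
    encdata = get_cipher(alg, key, iv).encrypt(data)
    decdata = get_cipher(alg, key, iv).decrypt(encdata)

    return decdata == data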
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for symmetric key encryption""" import os import unittest from asyncssh.cipher import get_encryption_algs, get_encryption_params from asyncssh.cipher import get_cipher from .util import libnacl_available class _TestCipher(unittest.TestCase): """Unit tests for cipher module""" def test_encryption_algs(self): """Unit test encryption algorithms""" for alg in get_encryption_algs(): with self.subTest(alg=alg): keysize, ivsize, blocksize, mode = get_encryption_params(alg) key = os.urandom(keysize) iv = os.urandom(ivsize) data = os.urandom(32*blocksize) enc_cipher = get_cipher(alg, key, iv) dec_cipher = get_cipher(alg, key, iv) badkey = bytearray(key) badkey[-1] ^= 0xff bad_cipher = get_cipher(alg, bytes(badkey), iv) hdr = os.urandom(4) if mode == 'chacha': nonce = os.urandom(8) enchdr = enc_cipher.crypt_len(hdr, nonce) encdata, mac = enc_cipher.encrypt_and_sign(hdr, data, nonce) dechdr = dec_cipher.crypt_len(enchdr, nonce) decdata = dec_cipher.verify_and_decrypt(dechdr, encdata, nonce, mac) badhdr = bad_cipher.crypt_len(enchdr, nonce) baddata = bad_cipher.verify_and_decrypt(badhdr, encdata, nonce, mac) self.assertIsNone(baddata) elif mode == 'gcm': dechdr = hdr encdata, mac = enc_cipher.encrypt_and_sign(hdr, data) decdata = dec_cipher.verify_and_decrypt(hdr, encdata, mac) baddata = bad_cipher.verify_and_decrypt(hdr, encdata, mac) self.assertIsNone(baddata) else: dechdr = hdr encdata1 = enc_cipher.encrypt(data[:len(data)//2]) encdata2 = enc_cipher.encrypt(data[len(data)//2:]) decdata = dec_cipher.decrypt(encdata1) decdata += dec_cipher.decrypt(encdata2) baddata = bad_cipher.decrypt(encdata1) baddata += bad_cipher.decrypt(encdata2) self.assertNotEqual(data, baddata) self.assertEqual(hdr, dechdr) self.assertEqual(data, decdata) if libnacl_available: # pragma: no branch def test_chacha_errors(self): """Unit test error code paths in chacha cipher""" alg = b'chacha20-poly1305@openssh.com' keysize, ivsize, _, _ = get_encryption_params(alg) key = os.urandom(keysize) iv = os.urandom(ivsize) with self.subTest('Chacha20Poly1305 key size error'): with self.assertRaises(ValueError): get_cipher(alg, key[:-1], iv) with self.subTest('Chacha20Poly1305 nonce size error'): cipher = get_cipher(alg, key, iv) with self.assertRaises(ValueError): cipher.crypt_len(b'', b'') with self.assertRaises(ValueError): cipher.encrypt_and_sign(b'', b'', b'') with self.assertRaises(ValueError): cipher.verify_and_decrypt(b'', b'', b'', b'') asyncssh-1.11.1/tests/test_compression.py000066400000000000000000000022761320320510200205350ustar00rootroot00000000000000# Copyright (c) 2015 by Ron Frederick . # All rights reserved. 
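# A minimal, illustrative sketch of the compress/decompress round trip that
# the test below performs for each registered algorithm, using the same
# internal asyncssh.compression helpers.  The zlib@openssh.com algorithm name
# and the helper name are chosen for illustration; get_compressor() returns a
# false value for algorithms that do no compression (such as 'none'), which
# is why the test below guards with "if compressor:".

def _compression_roundtrip_example(alg=b'zlib@openssh.com'):
    """Sketch: compress and decompress a block of data"""

    import os

    from asyncssh.compression import get_compressor, get_decompressor

    data = os.urandom(256)

    compressor = get_compressor(alg)
    decompressor = get_decompressor(alg)

    if not compressor:
        # Algorithm with no compression (e.g. b'none')
        return True

    return decompressor.decompress(compressor.compress(data)) == data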
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for compression""" import os import unittest from asyncssh.compression import get_compression_algs, get_compression_params from asyncssh.compression import get_compressor, get_decompressor class TestCompression(unittest.TestCase): """Unit tests for compression module""" def test_compression_algs(self): """Unit test compression algorithms""" for alg in get_compression_algs(): with self.subTest(alg=alg): get_compression_params(alg) data = os.urandom(256) compressor = get_compressor(alg) decompressor = get_decompressor(alg) if compressor: cmpdata = compressor.compress(data) self.assertEqual(decompressor.decompress(cmpdata), data) asyncssh-1.11.1/tests/test_connection.py000066400000000000000000001205271320320510200203330ustar00rootroot00000000000000# Copyright (c) 2016-2017 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for AsyncSSH connection API""" import asyncio from copy import copy import os import unittest from unittest.mock import patch import asyncssh from asyncssh.cipher import get_encryption_algs from asyncssh.constants import MSG_DEBUG from asyncssh.constants import MSG_SERVICE_REQUEST, MSG_SERVICE_ACCEPT from asyncssh.constants import MSG_KEXINIT, MSG_NEWKEYS from asyncssh.constants import MSG_USERAUTH_REQUEST, MSG_USERAUTH_SUCCESS from asyncssh.constants import MSG_USERAUTH_FAILURE, MSG_USERAUTH_BANNER from asyncssh.constants import MSG_GLOBAL_REQUEST from asyncssh.constants import MSG_CHANNEL_OPEN, MSG_CHANNEL_OPEN_CONFIRMATION from asyncssh.constants import MSG_CHANNEL_OPEN_FAILURE, MSG_CHANNEL_DATA from asyncssh.compression import get_compression_algs from asyncssh.crypto.pyca.cipher import GCMShim from asyncssh.kex import get_kex_algs from asyncssh.mac import _HMAC, _mac_handlers, get_mac_algs from asyncssh.packet import Boolean, Byte, NameList, String, UInt32 from .server import Server, ServerTestCase from .util import asynctest, gss_available, patch_gss, x509_available class _SplitClientConnection(asyncssh.SSHClientConnection): """Test SSH messages being split into multiple packets""" def data_received(self, data, datatype=None): """Handle incoming data on the connection""" super().data_received(data[:3], datatype) super().data_received(data[3:6], datatype) super().data_received(data[6:9], datatype) super().data_received(data[9:], datatype) class _ReplayKexClientConnection(asyncssh.SSHClientConnection): """Test starting SSH key exchange while it is in progress""" def replay_kex(self): """Replay last kexinit packet""" self.send_packet(self._client_kexinit) class _VersionedServerConnection(asyncssh.SSHServerConnection): """Test alternate SSH server version lines""" def __init__(self, version, leading_text, newline, *args, **kwargs): super().__init__(*args, **kwargs) self._version = version self._leading_text = leading_text self._newline = newline @classmethod def create(cls, version=b'SSH-2.0-AsyncSSH_Test', leading_text=b'', newline=b'\r\n'): 
"""Return a connection factory which sends modified version lines""" return (lambda *args, **kwargs: cls(version, leading_text, newline, *args, **kwargs)) def _send_version(self): """Start the SSH handshake""" self._server_version = self._version self._extra.update(server_version=self._version.decode('ascii')) self._send(self._leading_text + self._version + self._newline) class _BadHostKeyServerConnection(asyncssh.SSHServerConnection): """Test returning invalid server host key""" def get_server_host_key(self): """Return the chosen server host key""" result = copy(super().get_server_host_key()) result.public_data = b'xxx' return result class _ExtInfoServerConnection(asyncssh.SSHServerConnection): """Test adding an unrecognized extension in extension info""" def _send_ext_info(self): """Send extension information""" self._extensions_sent['xxx'] = b'' super()._send_ext_info() def _failing_get_mac(alg, key): """Replace HMAC class with FailingMAC""" class _FailingMAC(_HMAC): """Test error in MAC validation""" def verify(self, seq, packet, sig): """Verify the signature of a message""" return super().verify(seq, packet + b'\xff', sig) _, hash_size, args = _mac_handlers[alg] return _FailingMAC(key, hash_size, *args) class _FailingGCMShim(GCMShim): """Test error in GCM tag verification""" def verify_and_decrypt(self, header, data, tag): """Verify the signature of and decrypt a block of data""" return super().verify_and_decrypt(header, data + b'\xff', tag) class _InternalErrorClient(asyncssh.SSHClient): """Test of internal error exception handler""" def connection_made(self, conn): """Raise an error when a new connection is opened""" # pylint: disable=unused-argument raise RuntimeError('Exception handler test') class _AbortServer(Server): """Server for testing connection abort during auth""" def begin_auth(self, username): """Abort the connection during auth""" self._conn.abort() return False class _CloseDuringAuthServer(Server): """Server for testing connection close during long auth callback""" def password_auth_supported(self): """Return that password auth is supported""" return True @asyncio.coroutine def validate_password(self, username, password): """Delay validating password""" # pylint: disable=unused-argument yield from asyncio.sleep(1) return False class _InternalErrorServer(Server): """Server for testing internal error during auth""" def begin_auth(self, username): """Raise an internal error during auth""" raise RuntimeError('Exception handler test') class _InvalidAuthBannerServer(Server): """Server for testing invalid auth banner""" def begin_auth(self, username): """Send an invalid auth banner""" self._conn.send_auth_banner(b'\xff') return False class _VersionRecordingClient(asyncssh.SSHClient): """Client for testing custom client version""" def __init__(self): self.reported_version = None def auth_banner_received(self, msg, lang): """Record the client version reported in the auth banner""" self.reported_version = msg class _VersionReportingServer(Server): """Server for testing custom client version""" def begin_auth(self, username): """Report the client's version in the auth banner""" version = self._conn.get_extra_info('client_version') self._conn.send_auth_banner(version) return False @patch_gss class _TestConnection(ServerTestCase): """Unit tests for AsyncSSH connection API""" # pylint: disable=too-many-public-methods @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server to connect to""" return (yield from cls.create_server(gss_host=())) @asyncio.coroutine 
def _check_version(self, *args, **kwargs): """Check alternate SSH server version lines""" with patch('asyncssh.connection.SSHServerConnection', _VersionedServerConnection.create(*args, **kwargs)): with (yield from self.connect()) as conn: pass yield from conn.wait_closed() @asynctest def test_connect_no_loop(self): """Test connecting with loop not specified""" with (yield from self.connect(loop=None)) as conn: pass yield from conn.wait_closed() @asynctest def test_connect_failure(self): """Test failure connecting""" with self.assertRaises(OSError): yield from asyncssh.connect('0.0.0.1') @asynctest def test_connect_failure_without_agent(self): """Test failure connecting with SSH agent disabled""" with self.assertRaises(OSError): yield from asyncssh.connect('0.0.0.1', agent_path=None) @asynctest def test_split_version(self): """Test version split across two packets""" with patch('asyncssh.connection.SSHClientConnection', _SplitClientConnection): with (yield from self.connect()) as conn: pass yield from conn.wait_closed() @asynctest def test_version_1_99(self): """Test SSH server version 1.99""" yield from self._check_version(b'SSH-1.99-Test') @asynctest def test_text_before_version(self): """Test additional text before SSH server version""" yield from self._check_version(leading_text=b'Test\r\n') @asynctest def test_version_without_cr(self): """Test SSH server version with LF instead of CRLF""" yield from self._check_version(newline=b'\n') @asynctest def test_unknown_version(self): """Test unknown SSH server version""" with self.assertRaises(asyncssh.DisconnectError): yield from self._check_version(b'SSH-1.0-Test') @asynctest def test_no_server_host_keys(self): """Test starting a server with no host keys""" with self.assertRaises(ValueError): yield from asyncssh.listen(server_host_keys=[], gss_host=None) @asynctest def test_duplicate_type_server_host_keys(self): """Test starting a server with duplicate host key types""" with self.assertRaises(ValueError): yield from asyncssh.listen(server_host_keys=['skey', 'skey']) @asynctest def test_known_hosts_none(self): """Test connecting with known hosts checking disabled""" with (yield from self.connect(known_hosts=None)) as conn: pass yield from conn.wait_closed() @asynctest def test_known_hosts_none_without_x509(self): """Test connecting with known hosts checking and X.509 disabled""" with (yield from self.connect(known_hosts=None, x509_trusted_certs=None)) as conn: pass yield from conn.wait_closed() @asynctest def test_known_hosts_multiple_keys(self): """Test connecting with multiple trusted known hosts keys""" with (yield from self.connect(known_hosts=(['skey.pub', 'skey.pub'], [], []))) as conn: pass yield from conn.wait_closed() @asynctest def test_known_hosts_ca(self): """Test connecting with a known hosts CA""" with (yield from self.connect(known_hosts=([], ['skey.pub'], []))) as conn: pass yield from conn.wait_closed() @asynctest def test_known_hosts_bytes(self): """Test connecting with known hosts passed in as bytes""" with open('skey.pub', 'rb') as f: skey = f.read() with (yield from self.connect(known_hosts=([skey], [], []))) as conn: pass yield from conn.wait_closed() @asynctest def test_known_hosts_keylist_file(self): """Test connecting with known hosts passed as a keylist file""" with (yield from self.connect(known_hosts=('skey.pub', [], []))) as conn: pass yield from conn.wait_closed() @asynctest def test_known_hosts_sshkeys(self): """Test connecting with known hosts passed in as SSHKeys""" keylist = 
asyncssh.load_public_keys('skey.pub') with (yield from self.connect(known_hosts=(keylist, [], []))) as conn: pass yield from conn.wait_closed() @asynctest def test_read_known_hosts(self): """Test connecting with known hosts object from read_known_hosts""" known_hosts_path = os.path.join('.ssh', 'known_hosts') known_hosts = asyncssh.read_known_hosts(known_hosts_path) with (yield from self.connect(known_hosts=known_hosts)) as conn: pass yield from conn.wait_closed() @asynctest def test_import_known_hosts(self): """Test connecting with known hosts object from import_known_hosts""" known_hosts_path = os.path.join('.ssh', 'known_hosts') with open(known_hosts_path, 'r') as f: known_hosts = asyncssh.import_known_hosts(f.read()) with (yield from self.connect(known_hosts=known_hosts)) as conn: pass yield from conn.wait_closed() @asynctest def test_untrusted_known_hosts_key(self): """Test untrusted server host key""" with self.assertRaises(asyncssh.DisconnectError): yield from self.connect(known_hosts=(['ckey.pub'], [], [])) @asynctest def test_untrusted_known_hosts_ca(self): """Test untrusted server CA key""" with self.assertRaises(asyncssh.DisconnectError): yield from self.connect(known_hosts=([], ['ckey.pub'], [])) @asynctest def test_revoked_known_hosts_key(self): """Test revoked server host key""" with self.assertRaises(asyncssh.DisconnectError): yield from self.connect(known_hosts=(['ckey.pub'], [], ['skey.pub'])) @asynctest def test_revoked_known_hosts_ca(self): """Test revoked server CA key""" with self.assertRaises(asyncssh.DisconnectError): yield from self.connect(known_hosts=([], ['ckey.pub'], ['skey.pub'])) @asynctest def test_empty_known_hosts(self): """Test empty known hosts list""" with self.assertRaises(asyncssh.DisconnectError): yield from self.connect(known_hosts=([], [], [])) @asynctest def test_invalid_server_host_key(self): """Test invalid server host key""" with patch('asyncssh.connection.SSHServerConnection', _BadHostKeyServerConnection): with self.assertRaises(asyncssh.DisconnectError): yield from self.connect() @asynctest def test_kex_algs(self): """Test connecting with different key exchange algorithms""" for kex in get_kex_algs(): kex = kex.decode('ascii') if kex.startswith('gss-') and not gss_available: # pragma: no cover continue with self.subTest(kex_alg=kex): with (yield from self.connect(kex_algs=[kex], gss_host='1')) as conn: pass yield from conn.wait_closed() @asynctest def test_empty_kex_algs(self): """Test connecting with an empty list of key exchange algorithms""" with self.assertRaises(ValueError): yield from self.connect(kex_algs=[]) @asynctest def test_invalid_kex_alg(self): """Test connecting with invalid key exchange algorithm""" with self.assertRaises(ValueError): yield from self.connect(kex_algs=['xxx']) @asynctest def test_unsupported_kex_alg(self): """Test connecting with unsupported key exchange algorithm""" def unsupported_kex_alg(): """Patched version of get_kex_algs to test unsupported algorithm""" return [b'fail'] + get_kex_algs() with patch('asyncssh.connection.get_kex_algs', unsupported_kex_alg): with self.assertRaises(asyncssh.DisconnectError): yield from self.connect(kex_algs=['fail']) @asynctest def test_skip_ext_info(self): """Test not requesting extension info from the server""" def skip_ext_info(self): """Don't request extension information""" # pylint: disable=unused-argument return [] with patch('asyncssh.connection.SSHConnection._get_ext_info_kex_alg', skip_ext_info): with (yield from self.connect()) as conn: pass yield from 
conn.wait_closed() @asynctest def test_unknown_ext_info(self): """Test receiving unknown extension information""" with patch('asyncssh.connection.SSHServerConnection', _ExtInfoServerConnection): with (yield from self.connect()) as conn: pass yield from conn.wait_closed() @asynctest def test_server_ext_info(self): """Test receiving unsolicited extension information on server""" def send_newkeys(self, k, h): """Finish a key exchange and send a new keys message""" asyncssh.connection.SSHConnection.send_newkeys(self, k, h) self._send_ext_info() with patch('asyncssh.connection.SSHClientConnection.send_newkeys', send_newkeys): with (yield from self.connect()) as conn: pass yield from conn.wait_closed() @asynctest def test_encryption_algs(self): """Test connecting with different encryption algorithms""" for enc in get_encryption_algs(): enc = enc.decode('ascii') with self.subTest(encryption_alg=enc): with (yield from self.connect(encryption_algs=[enc])) as conn: pass yield from conn.wait_closed() @asynctest def test_empty_encryption_algs(self): """Test connecting with an empty list of encryption algorithms""" with self.assertRaises(ValueError): yield from self.connect(encryption_algs=[]) @asynctest def test_invalid_encryption_alg(self): """Test connecting with invalid encryption algorithm""" with self.assertRaises(ValueError): yield from self.connect(encryption_algs=['xxx']) @asynctest def test_mac_algs(self): """Test connecting with different MAC algorithms""" for mac in get_mac_algs(): mac = mac.decode('ascii') with self.subTest(mac_alg=mac): with (yield from self.connect(encryption_algs=['aes128-ctr'], mac_algs=[mac])) as conn: pass yield from conn.wait_closed() @asynctest def test_mac_verify_error(self): """Test MAC validation failure""" with patch('asyncssh.connection.get_mac', _failing_get_mac): for mac in ('hmac-sha2-256-etm@openssh.com', 'hmac-sha2-256'): with self.subTest(mac_alg=mac): with self.assertRaises(asyncssh.DisconnectError): yield from self.connect(encryption_algs=['aes128-ctr'], mac_algs=[mac]) @asynctest def test_gcm_verify_error(self): """Test GCM tag validation failure""" with patch('asyncssh.crypto.pyca.cipher.GCMShim', _FailingGCMShim): with self.assertRaises(asyncssh.DisconnectError): yield from self.connect( encryption_algs=['aes128-gcm@openssh.com']) @asynctest def test_empty_mac_algs(self): """Test connecting with an empty list of MAC algorithms""" with self.assertRaises(ValueError): yield from self.connect(mac_algs=[]) @asynctest def test_invalid_mac_alg(self): """Test connecting with invalid MAC algorithm""" with self.assertRaises(ValueError): yield from self.connect(mac_algs=['xxx']) @asynctest def test_compression_algs(self): """Test connecting with different compression algorithms""" for cmp in get_compression_algs(): cmp = cmp.decode('ascii') with self.subTest(cmp_alg=cmp): with (yield from self.connect(compression_algs=[cmp])) as conn: pass yield from conn.wait_closed() @asynctest def test_no_compression(self): """Test connecting with compression disabled""" with (yield from self.connect(compression_algs=None)) as conn: pass yield from conn.wait_closed() @asynctest def test_invalid_cmp_alg(self): """Test connecting with invalid compression algorithm""" with self.assertRaises(ValueError): yield from self.connect(compression_algs=['xxx']) @asynctest def test_disconnect(self): """Test sending disconnect message""" conn = yield from self.connect() conn.disconnect(asyncssh.DISC_BY_APPLICATION, 'Closing') yield from conn.wait_closed() @asynctest def 
test_invalid_disconnect(self): """Test sending disconnect message with invalid Unicode in it""" conn = yield from self.connect() conn.disconnect(asyncssh.DISC_BY_APPLICATION, b'\xff') yield from conn.wait_closed() @asynctest def test_debug(self): """Test sending debug message""" with (yield from self.connect()) as conn: conn.send_debug('debug') yield from conn.wait_closed() @asynctest def test_invalid_debug(self): """Test sending debug message with invalid Unicode in it""" conn = yield from self.connect() conn.send_debug(b'\xff') yield from conn.wait_closed() @asynctest def test_invalid_service_request(self): """Test invalid service request""" conn = yield from self.connect() conn.send_packet(Byte(MSG_SERVICE_REQUEST), String('xxx')) yield from conn.wait_closed() @asynctest def test_invalid_service_accept(self): """Test invalid service accept""" conn = yield from self.connect() conn.send_packet(Byte(MSG_SERVICE_ACCEPT), String('xxx')) yield from conn.wait_closed() @asynctest def test_packet_decode_error(self): """Test SSH packet decode error""" conn = yield from self.connect() conn.send_packet(Byte(MSG_DEBUG)) yield from conn.wait_closed() @asynctest def test_unknown_packet(self): """Test unknown SSH packet""" with (yield from self.connect()) as conn: conn.send_packet(b'\xff') yield from asyncio.sleep(0.1) yield from conn.wait_closed() @asynctest def test_rekey(self): """Test SSH re-keying""" with (yield from self.connect(rekey_bytes=1)) as conn: yield from asyncio.sleep(0.1) conn.send_debug('test') yield from asyncio.sleep(0.1) yield from conn.wait_closed() @asynctest def test_kex_in_progress(self): """Test starting SSH key exchange while it is in progress""" with patch('asyncssh.connection.SSHClientConnection', _ReplayKexClientConnection): conn = yield from self.connect() conn.replay_kex() conn.replay_kex() yield from conn.wait_closed() @asynctest def test_no_matching_kex_algs(self): """Test no matching key exchange algorithms""" conn = yield from self.connect() conn.send_packet(Byte(MSG_KEXINIT), os.urandom(16), NameList([b'xxx']), NameList([]), NameList([]), NameList([]), NameList([]), NameList([]), NameList([]), NameList([]), NameList([]), NameList([]), Boolean(False), UInt32(0)) yield from conn.wait_closed() @asynctest def test_no_matching_host_key_algs(self): """Test no matching server host key algorithms""" conn = yield from self.connect() conn.send_packet(Byte(MSG_KEXINIT), os.urandom(16), NameList([b'ecdh-sha2-nistp521']), NameList([b'xxx']), NameList([]), NameList([]), NameList([]), NameList([]), NameList([]), NameList([]), NameList([]), NameList([]), Boolean(False), UInt32(0)) yield from conn.wait_closed() @asynctest def test_invalid_newkeys(self): """Test invalid new keys request""" conn = yield from self.connect() conn.send_packet(Byte(MSG_NEWKEYS)) yield from conn.wait_closed() @asynctest def test_invalid_userauth_service(self): """Test invalid service in userauth request""" conn = yield from self.connect() conn.send_packet(Byte(MSG_USERAUTH_REQUEST), String('guest'), String('xxx'), String('none')) yield from conn.wait_closed() @asynctest def test_invalid_username(self): """Test invalid username in userauth request""" conn = yield from self.connect() conn.send_packet(Byte(MSG_USERAUTH_REQUEST), String(b'\xff'), String('ssh-connection'), String('none')) yield from conn.wait_closed() @asynctest def test_extra_userauth_request(self): """Test userauth request after auth is complete""" with (yield from self.connect()) as conn: conn.send_packet(Byte(MSG_USERAUTH_REQUEST), 
String('guest'), String('ssh-connection'), String('none')) yield from asyncio.sleep(0.1) yield from conn.wait_closed() @asynctest def test_unexpected_userauth_success(self): """Test unexpected userauth success response""" conn = yield from self.connect() conn.send_packet(Byte(MSG_USERAUTH_SUCCESS)) yield from conn.wait_closed() @asynctest def test_unexpected_userauth_failure(self): """Test unexpected userauth failure response""" conn = yield from self.connect() conn.send_packet(Byte(MSG_USERAUTH_FAILURE), NameList([]), Boolean(False)) yield from conn.wait_closed() @asynctest def test_unexpected_userauth_banner(self): """Test unexpected userauth banner""" conn = yield from self.connect() conn.send_packet(Byte(MSG_USERAUTH_BANNER), String(''), String('')) yield from conn.wait_closed() @asynctest def test_invalid_global_request(self): """Test invalid global request""" conn = yield from self.connect() conn.send_packet(Byte(MSG_GLOBAL_REQUEST), String(b'\xff'), Boolean(True)) yield from conn.wait_closed() @asynctest def test_unexpected_global_response(self): """Test unexpected global response""" conn = yield from self.connect() conn.send_packet(Byte(MSG_GLOBAL_REQUEST), String('xxx'), Boolean(True)) yield from conn.wait_closed() @asynctest def test_invalid_channel_open(self): """Test invalid channel open request""" conn = yield from self.connect() conn.send_packet(Byte(MSG_CHANNEL_OPEN), String(b'\xff'), UInt32(0), UInt32(0), UInt32(0)) yield from conn.wait_closed() @asynctest def test_unknown_channel_type(self): """Test unknown channel open type""" conn = yield from self.connect() conn.send_packet(Byte(MSG_CHANNEL_OPEN), String('xxx'), UInt32(0), UInt32(0), UInt32(0)) yield from conn.wait_closed() @asynctest def test_invalid_channel_open_confirmation_number(self): """Test invalid channel number in open confirmation""" conn = yield from self.connect() conn.send_packet(Byte(MSG_CHANNEL_OPEN_CONFIRMATION), UInt32(0xff), UInt32(0), UInt32(0), UInt32(0)) yield from conn.wait_closed() @asynctest def test_invalid_channel_open_failure_number(self): """Test invalid channel number in open failure""" conn = yield from self.connect() conn.send_packet(Byte(MSG_CHANNEL_OPEN_FAILURE), UInt32(0xff), UInt32(0), String(''), String('')) yield from conn.wait_closed() @asynctest def test_invalid_channel_open_failure_reason(self): """Test invalid reason in channel open failure""" conn = yield from self.connect() conn.send_packet(Byte(MSG_CHANNEL_OPEN_FAILURE), UInt32(0), UInt32(0), String(b'\xff'), String('')) yield from conn.wait_closed() @asynctest def test_invalid_channel_open_failure_language(self): """Test invalid language in channel open failure""" conn = yield from self.connect() conn.send_packet(Byte(MSG_CHANNEL_OPEN_FAILURE), UInt32(0), UInt32(0), String(''), String(b'\xff')) yield from conn.wait_closed() @asynctest def test_invalid_data_channel_number(self): """Test invalid channel number in channel data message""" conn = yield from self.connect() conn.send_packet(Byte(MSG_CHANNEL_DATA), String('')) yield from conn.wait_closed() @asynctest def test_internal_error(self): """Test internal error in client callback""" with self.assertRaises(RuntimeError): yield from self.create_connection(_InternalErrorClient) class _TestConnectionAbort(ServerTestCase): """Unit test for connection abort""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server which aborts connections during auth""" return (yield from cls.create_server(_AbortServer)) @asynctest def test_abort(self): """Test connection 
abort""" with self.assertRaises(asyncssh.DisconnectError): yield from self.connect() class _TestConnectionCloseDurngAuth(ServerTestCase): """Unit test for connection close during long auth callback""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server which aborts connections during auth""" return (yield from cls.create_server(_CloseDuringAuthServer)) @asynctest def test_close_during_auth(self): """Test connection close during long auth callback""" with self.assertRaises(asyncio.TimeoutError): yield from asyncio.wait_for(self.connect(username='user', password=''), 0.5) @unittest.skipUnless(x509_available, 'X.509 not available') class _TestServerX509Self(ServerTestCase): """Unit test for server with self-signed X.509 host certificate""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server with a self-signed X.509 host certificate""" return (yield from cls.create_server( server_host_keys=['skey_x509_self'])) @asynctest def test_connect_x509_self(self): """Test connecting with X.509 self-signed certificate""" with (yield from self.connect(known_hosts=([], [], [], ['skey_x509_self.pem'], [], [], []))) as conn: pass yield from conn.wait_closed() @asynctest def test_connect_x509_untrusted_self(self): """Test connecting with untrusted X.509 self-signed certficate""" with self.assertRaises(asyncssh.DisconnectError): yield from self.connect( known_hosts=([], [], [], ['root_ca_cert.pem'], [], [], [])) @asynctest def test_connect_x509_revoked_self(self): """Test connecting with revoked X.509 self-signed certficate""" with self.assertRaises(asyncssh.DisconnectError): yield from self.connect( known_hosts=([], [], [], ['root_ca_cert.pem'], ['skey_x509_self.pem'], [], [])) @asynctest def test_connect_x509_trusted_subject(self): """Test connecting to server with trusted X.509 subject name""" with (yield from self.connect( known_hosts=([], [], [], [], [], ['OU=name'], ['OU=name1']), x509_trusted_certs=['skey_x509_self.pem'])) as conn: pass yield from conn.wait_closed() @asynctest def test_connect_x509_untrusted_subject(self): """Test connecting to server with untrusted X.509 subject name""" with self.assertRaises(asyncssh.DisconnectError): yield from self.connect( known_hosts=([], [], [], [], [], ['OU=name1'], []), x509_trusted_certs=['skey_x509_self.pem']) @asynctest def test_connect_x509_revoked_subject(self): """Test connecting to server with revoked X.509 subject name""" with self.assertRaises(asyncssh.DisconnectError): yield from self.connect( known_hosts=([], [], [], [], [], [], ['OU=name']), x509_trusted_certs=['skey_x509_self.pem']) @asynctest def test_connect_x509_disabled(self): """Test connecting to X.509 server with X.509 disabled""" with self.assertRaises(asyncssh.DisconnectError): yield from self.connect( known_hosts=([], [], [], [], [], ['OU=name'], []), x509_trusted_certs=None) @unittest.skipUnless(x509_available, 'X.509 not available') class _TestServerX509Chain(ServerTestCase): """Unit test for server with X.509 host certificate chain""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server with an X.509 host certificate chain""" return (yield from cls.create_server( server_host_keys=['skey_x509_chain'])) @asynctest def test_connect_x509_chain(self): """Test connecting with X.509 certificate chain""" with (yield from self.connect(known_hosts=([], [], [], ['root_ca_cert.pem'], [], [], []))) as conn: pass yield from conn.wait_closed() @asynctest def test_connect_x509_chain_cert_path(self): """Test connecting with 
X.509 certificate and certificate path""" with (yield from self.connect(x509_trusted_cert_paths=['cert_path'], known_hosts=b'\n')) as conn: pass yield from conn.wait_closed() @asynctest def test_connect_x509_untrusted_root(self): """Test connecting to server with untrusted X.509 root CA""" with self.assertRaises(asyncssh.DisconnectError): yield from self.connect(known_hosts=([], [], [], ['skey_x509_self.pem'], [], [], [])) @asynctest def test_connect_x509_untrusted_root_cert_path(self): """Test connecting to server with untrusted X.509 root CA""" with self.assertRaises(asyncssh.DisconnectError): yield from self.connect(known_hosts=b'\n') @asynctest def test_connect_x509_revoked_intermediate(self): """Test connecting to server with revoked X.509 intermediate CA""" with self.assertRaises(asyncssh.DisconnectError): yield from self.connect(known_hosts=([], [], [], ['root_ca_cert.pem'], ['int_ca_cert.pem'], [], [])) @asynctest def test_invalid_x509_path(self): """Test passing in invalid trusted X.509 certificate path""" with self.assertRaises(ValueError): yield from self.connect(x509_trusted_cert_paths='xxx') class _TestServerNoLoop(ServerTestCase): """Unit test for server with no loop specified""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server which has no loop specified""" return (yield from cls.create_server(loop=None)) @asynctest def test_server_no_loop(self): """Test server with no loop specified""" with (yield from self.connect()) as conn: pass yield from conn.wait_closed() @unittest.skipUnless(gss_available, 'GSS not available') @patch_gss class _TestServerNoHostKey(ServerTestCase): """Unit test for server with no server host key""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server which sets no server host keys""" return (yield from cls.create_server(server_host_keys=None, gss_host='1')) @asynctest def test_gss_with_no_host_key(self): """Test GSS key exchange with no server host key specified""" with (yield from self.connect(known_hosts=b'\n', gss_host='1', x509_trusted_certs=None, x509_trusted_cert_paths=None)) as conn: pass yield from conn.wait_closed() @asynctest def test_dh_with_no_host_key(self): """Test failure of DH key exchange with no server host key specified""" with self.assertRaises(asyncssh.DisconnectError): yield from self.connect() class _TestServerInternalError(ServerTestCase): """Unit test for server internal error during auth""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server which raises an error during auth""" return (yield from cls.create_server(_InternalErrorServer)) @asynctest def test_server_internal_error(self): """Test server internal error during auth""" with self.assertRaises(asyncssh.DisconnectError): yield from self.connect() class _TestInvalidAuthBanner(ServerTestCase): """Unit test for invalid auth banner""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server which sends invalid auth banner""" return (yield from cls.create_server(_InvalidAuthBannerServer)) @asynctest def test_abort(self): """Test server sending invalid auth banner""" with self.assertRaises(asyncssh.DisconnectError): yield from self.connect() class _TestExpiredServerHostCertificate(ServerTestCase): """Unit tests for expired server host certificate""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server with an expired host certificate""" return (yield from cls.create_server(server_host_keys=['exp_skey'])) @asynctest def 
test_expired_server_host_cert(self): """Test expired server host certificate""" with self.assertRaises(asyncssh.DisconnectError): yield from self.connect(known_hosts=([], ['skey.pub'], [])) class _TestCustomClientVersion(ServerTestCase): """Unit test for custom SSH client version""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server which sends client version in auth banner""" return (yield from cls.create_server(_VersionReportingServer)) @asyncio.coroutine def _check_client_version(self, version): """Check custom client version""" conn, client = \ yield from self.create_connection(_VersionRecordingClient, client_version=version) with conn: self.assertEqual(client.reported_version, 'SSH-2.0-custom') yield from conn.wait_closed() @asynctest def test_custom_client_version(self): """Test custom client version""" yield from self._check_client_version('custom') @asynctest def test_custom_client_version_bytes(self): """Test custom client version set as bytes""" yield from self._check_client_version(b'custom') @asynctest def test_long_client_version(self): """Test client version which is too long""" with self.assertRaises(ValueError): yield from self.connect(client_version=246*'a') @asynctest def test_nonprintable_client_version(self): """Test client version with non-printable character""" with self.assertRaises(ValueError): yield from self.connect(client_version='xxx\0') class _TestCustomServerVersion(ServerTestCase): """Unit test for custom SSH server version""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server which sends a custom version""" return (yield from cls.create_server(server_version='custom')) @asynctest def test_custom_server_version(self): """Test custom server version""" with (yield from self.connect()) as conn: version = conn.get_extra_info('server_version') self.assertEqual(version, 'SSH-2.0-custom') yield from conn.wait_closed() @asynctest def test_long_server_version(self): """Test server version which is too long""" with self.assertRaises(ValueError): yield from self.create_server(server_version=246*'a') @asynctest def test_nonprintable_server_version(self): """Test server version with non-printable character""" with self.assertRaises(ValueError): yield from self.create_server(server_version='xxx\0') asyncssh-1.11.1/tests/test_connection_auth.py000066400000000000000000001164761320320510200213640ustar00rootroot00000000000000# Copyright (c) 2016-2017 by Ron Frederick . # All rights reserved. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for AsyncSSH connection authentication""" import asyncio import os import unittest from unittest.mock import patch import asyncssh from asyncssh.packet import String from asyncssh.public_key import CERT_TYPE_USER from .server import Server, ServerTestCase from .util import asynctest, gss_available, patch_gss, make_certificate from .util import x509_available class _AsyncGSSServer(asyncssh.SSHServer): """Server for testing async GSS authentication""" @asyncio.coroutine def validate_gss_principal(self, username, user_principal, host_principal): """Return whether password is valid for this user""" return super().validate_gss_principal(username, user_principal, host_principal) class _PublicKeyClient(asyncssh.SSHClient): """Test client public key authentication""" def __init__(self, keylist, delay=0): self._keylist = keylist self._delay = delay @asyncio.coroutine def public_key_auth_requested(self): """Return a public key to authenticate with""" if self._delay: yield from asyncio.sleep(self._delay) return self._keylist.pop(0) if self._keylist else None class _AsyncPublicKeyClient(_PublicKeyClient): """Test async client public key authentication""" @asyncio.coroutine def public_key_auth_requested(self): """Return a public key to authenticate with""" return super().public_key_auth_requested() class _PublicKeyServer(Server): """Server for testing public key authentication""" def __init__(self, client_keys=(), authorized_keys=None): super().__init__() self._client_keys = client_keys self._authorized_keys = authorized_keys def connection_made(self, conn): """Called when a connection is made""" super().connection_made(conn) conn.send_auth_banner('auth banner') def begin_auth(self, username): """Handle client authentication request""" if self._authorized_keys: self._conn.set_authorized_keys(self._authorized_keys) else: self._client_keys = asyncssh.load_public_keys(self._client_keys) return True def public_key_auth_supported(self): """Return whether or not public key authentication is supported""" return True def validate_public_key(self, username, key): """Return whether key is an authorized client key for this user""" return key in self._client_keys def validate_ca_key(self, username, key): """Return whether key is an authorized CA key for this user""" return key in self._client_keys class _AsyncPublicKeyServer(_PublicKeyServer): """Server for testing async public key authentication""" @asyncio.coroutine def validate_public_key(self, username, key): """Return whether key is an authorized client key for this user""" return super().validate_public_key(username, key) @asyncio.coroutine def validate_ca_key(self, username, key): """Return whether key is an authorized CA key for this user""" return super().validate_ca_key(username, key) class _PasswordClient(asyncssh.SSHClient): """Test client password authentication""" def __init__(self, password, old_password, new_password): self._password = password self._old_password = old_password self._new_password = new_password def password_auth_requested(self): """Return a password to authenticate with""" if self._password: result = self._password self._password = None return result else: return None def password_change_requested(self, prompt, 
lang): """Change the client's password""" return self._old_password, self._new_password class _AsyncPasswordClient(_PasswordClient): """Test async client password authentication""" @asyncio.coroutine def password_auth_requested(self): """Return a password to authenticate with""" return super().password_auth_requested() @asyncio.coroutine def password_change_requested(self, prompt, lang): """Change the client's password""" return super().password_change_requested(prompt, lang) class _PasswordServer(Server): """Server for testing password authentication""" def password_auth_supported(self): """Enable password authentication""" return True def validate_password(self, username, password): """Accept password of pw, trigger password change on oldpw""" if password == 'oldpw': raise asyncssh.PasswordChangeRequired('Password change required') else: return password == 'pw' def change_password(self, username, old_password, new_password): """Only allow password change from password oldpw""" return old_password == 'oldpw' class _AsyncPasswordServer(_PasswordServer): """Server for testing async password authentication""" @asyncio.coroutine def validate_password(self, username, password): """Return whether password is valid for this user""" return super().validate_password(username, password) @asyncio.coroutine def change_password(self, username, old_password, new_password): """Handle a request to change a user's password""" return super().change_password(username, old_password, new_password) class _KbdintClient(asyncssh.SSHClient): """Test keyboard-interactive client auth""" def __init__(self, responses): self._responses = responses def kbdint_auth_requested(self): """Return the list of supported keyboard-interactive auth methods""" return '' if self._responses else None def kbdint_challenge_received(self, name, instructions, lang, prompts): """Return responses to a keyboard-interactive auth challenge""" # pylint: disable=unused-argument if len(prompts) == 0: return [] elif self._responses: result = self._responses self._responses = None return result else: return None class _AsyncKbdintClient(_KbdintClient): """Test keyboard-interactive client auth""" @asyncio.coroutine def kbdint_auth_requested(self): """Return the list of supported keyboard-interactive auth methods""" return super().kbdint_auth_requested() @asyncio.coroutine def kbdint_challenge_received(self, name, instructions, lang, prompts): """Return responses to a keyboard-interactive auth challenge""" return super().kbdint_challenge_received(name, instructions, lang, prompts) class _KbdintServer(Server): """Server for testing keyboard-interactive authentication""" def __init__(self): super().__init__() self._kbdint_round = 0 def kbdint_auth_supported(self): """Enable keyboard-interactive authentication""" return True def get_kbdint_challenge(self, username, lang, submethods): """Return an initial challenge with only instructions""" return '', 'instructions', '', [] def validate_kbdint_response(self, username, responses): """Return a password challenge after the instructions""" if self._kbdint_round == 0: result = ('', '', '', [('Password:', False)]) else: if len(responses) == 1 and responses[0] == 'kbdint': result = True else: result = ('', '', '', [('Other Challenge:', True)]) self._kbdint_round += 1 return result class _AsyncKbdintServer(_KbdintServer): """Server for testing async keyboard-interactive authentication""" @asyncio.coroutine def get_kbdint_challenge(self, username, lang, submethods): """Return a keyboard-interactive auth 
challenge""" return super().get_kbdint_challenge(username, lang, submethods) @asyncio.coroutine def validate_kbdint_response(self, username, responses): """Return whether the keyboard-interactive response is valid for this user""" return super().validate_kbdint_response(username, responses) class _UnknownAuthClientConnection(asyncssh.connection.SSHClientConnection): """Test getting back an unknown auth method from the SSH server""" def try_next_auth(self): """Attempt client authentication using an unknown method""" self._auth_methods = [b'unknown'] + self._auth_methods super().try_next_auth() @unittest.skipUnless(gss_available, 'GSS not available') @patch_gss class _TestGSSAuth(ServerTestCase): """Unit tests for GSS authentication""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server which supports GSS authentication""" return (yield from cls.create_server(_AsyncGSSServer, gss_host='1')) @asynctest def test_gss_kex_auth(self): """Test GSS key exchange authentication""" with (yield from self.connect(kex_algs=['gss-gex-sha1'], username='user', gss_host='1')) as conn: pass yield from conn.wait_closed() @asynctest def test_gss_mic_auth(self): """Test GSS MIC authentication""" with (yield from self.connect(kex_algs=['ecdh-sha2-nistp256'], username='user', gss_host='1')) as conn: pass yield from conn.wait_closed() @asynctest def test_gss_auth_unavailable(self): """Test GSS authentication being unavailable""" with self.assertRaises(asyncssh.DisconnectError): yield from self.connect(username='user1', gss_host=()) @asynctest def test_gss_client_error(self): """Test GSS client error""" with self.assertRaises(asyncssh.DisconnectError): yield from self.connect(gss_host='1,init_error', username='user') @unittest.skipUnless(gss_available, 'GSS not available') @patch_gss class _TestGSSServerError(ServerTestCase): """Unit tests for GSS server error""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server which raises an error on GSS authentication""" return (yield from cls.create_server(gss_host='1,init_error')) @asynctest def test_gss_server_error(self): """Test GSS error on server""" with self.assertRaises(asyncssh.DisconnectError): yield from self.connect(username='user') @unittest.skipUnless(gss_available, 'GSS not available') @patch_gss class _TestGSSFQDN(ServerTestCase): """Unit tests for GSS server error""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server which raises an error on GSS authentication""" def mock_gethostname(): """Return a non-fully-qualified hostname""" return 'host' def mock_getfqdn(): """Confirm getfqdn is called on relative hostnames""" return '1' with patch('socket.gethostname', mock_gethostname): with patch('socket.getfqdn', mock_getfqdn): return (yield from cls.create_server(gss_host=())) @asynctest def test_gss_fqdn_lookup(self): """Test GSS FQDN lookup""" with (yield from self.connect(username='user', gss_host=())) as conn: pass yield from conn.wait_closed() class _TestPublicKeyAuth(ServerTestCase): """Unit tests for public key authentication""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server which supports public key authentication""" return (yield from cls.create_server( _PublicKeyServer, authorized_client_keys='authorized_keys')) @asyncio.coroutine def _connect_publickey(self, keylist, test_async=False): """Open a connection to test public key auth""" def client_factory(): """Return an SSHClient to use to do public key auth""" cls = _AsyncPublicKeyClient if 
test_async else _PublicKeyClient return cls(keylist) conn, _ = yield from self.create_connection(client_factory, username='ckey', client_keys=None) return conn @asynctest def test_agent_auth(self): """Test connecting with ssh-agent authentication""" if not self.agent_available(): # pragma: no cover self.skipTest('ssh-agent not available') with (yield from self.connect(username='ckey')) as conn: pass yield from conn.wait_closed() @asynctest def test_agent_signature_algs(self): """Test ssh-agent keys with specific signature algorithms""" if not self.agent_available(): # pragma: no cover self.skipTest('ssh-agent not available') for alg in ('ssh-rsa', 'rsa-sha2-256', 'rsa-sha2-512'): with (yield from self.connect(username='ckey', signature_algs=[alg])) as conn: pass yield from conn.wait_closed() @asynctest def test_agent_auth_failure(self): """Test failure connecting with ssh-agent authentication""" if not self.agent_available(): # pragma: no cover self.skipTest('ssh-agent not available') with patch.dict(os.environ, HOME='xxx'): with self.assertRaises(asyncssh.DisconnectError): yield from self.connect(username='ckey', agent_path='xxx', known_hosts='.ssh/known_hosts') @asynctest def test_agent_auth_unset(self): """Test connecting with no local keys and no ssh-agent configured""" with patch.dict(os.environ, HOME='xxx', SSH_AUTH_SOCK=''): with self.assertRaises(asyncssh.DisconnectError): yield from self.connect(username='ckey', known_hosts='.ssh/known_hosts') @asynctest def test_public_key_auth(self): """Test connecting with public key authentication""" with (yield from self.connect(username='ckey', client_keys='ckey')) as conn: pass yield from conn.wait_closed() @asynctest def test_public_key_signature_algs(self): """Test public key authentication with specific signature algorithms""" for alg in ('ssh-rsa', 'rsa-sha2-256', 'rsa-sha2-512'): with (yield from self.connect(username='ckey', client_keys='ckey', signature_algs=[alg])) as conn: pass yield from conn.wait_closed() @asynctest def test_no_server_signature_algs(self): """Test a server which doesn't advertise signature algorithms""" def skip_ext_info(self): """Don't send extension information""" # pylint: disable=unused-argument return [] with patch('asyncssh.connection.SSHConnection._get_ext_info_kex_alg', skip_ext_info): with (yield from self.connect(username='ckey', client_keys='ckey', agent_path=None)) as conn: pass yield from conn.wait_closed() @asynctest def test_default_public_key_auth(self): """Test connecting with default public key authentication""" with (yield from self.connect(username='ckey', agent_path=None)) as conn: pass yield from conn.wait_closed() @asynctest def test_invalid_default_key(self): """Test connecting with invalid default client key""" key_path = os.path.join('.ssh', 'id_dsa') with open(key_path, 'w') as f: f.write('-----XXX-----') with self.assertRaises(asyncssh.KeyImportError): yield from self.connect(username='ckey', agent_path=None) os.remove(key_path) @asynctest def test_client_key_bytes(self): """Test client key passed in as bytes""" with open('ckey', 'rb') as f: ckey = f.read() with (yield from self.connect(username='ckey', client_keys=[ckey])) as conn: pass yield from conn.wait_closed() @asynctest def test_client_key_sshkey(self): """Test client key passed in as an SSHKey""" ckey = asyncssh.read_private_key('ckey') with (yield from self.connect(username='ckey', client_keys=[ckey])) as conn: pass yield from conn.wait_closed() @asynctest def test_client_key_keypairs(self): """Test client keys passed in as a 
list of SSHKeyPairs""" keys = asyncssh.load_keypairs('ckey') with (yield from self.connect(username='ckey', client_keys=keys)) as conn: pass yield from conn.wait_closed() @asynctest def test_client_key_agent_keypairs(self): """Test client keys passed in as a list of SSHAgentKeyPairs""" if not self.agent_available(): # pragma: no cover self.skipTest('ssh-agent not available') agent = yield from asyncssh.connect_agent() for key in (yield from agent.get_keys()): with (yield from self.connect(username='ckey', client_keys=[key])) as conn: pass yield from conn.wait_closed() agent.close() @asynctest def test_untrusted_client_key(self): """Test untrusted client key""" with self.assertRaises(asyncssh.DisconnectError): yield from self.connect(username='ckey', client_keys='skey') @asynctest def test_missing_cert(self): """Test missing client certificate""" with self.assertRaises(OSError): yield from self.connect(username='ckey', client_keys=[('ckey', 'xxx')]) @asynctest def test_expired_cert(self): """Test expired certificate""" ckey = asyncssh.read_private_key('ckey') skey = asyncssh.read_private_key('skey') cert = make_certificate('ssh-rsa-cert-v01@openssh.com', CERT_TYPE_USER, skey, ckey, ['ckey'], valid_before=1) with self.assertRaises(asyncssh.DisconnectError): yield from self.connect(username='ckey', client_keys=[(skey, cert)]) @asynctest def test_allowed_address(self): """Test allowed address in certificate""" ckey = asyncssh.read_private_key('ckey') skey = asyncssh.read_private_key('skey') cert = make_certificate('ssh-rsa-cert-v01@openssh.com', CERT_TYPE_USER, skey, ckey, ['ckey'], options={'source-address': String('0.0.0.0/0,::/0')}) with (yield from self.connect(username='ckey', client_keys=[(skey, cert)])) as conn: pass yield from conn.wait_closed() @asynctest def test_disallowed_address(self): """Test disallowed address in certificate""" ckey = asyncssh.read_private_key('ckey') skey = asyncssh.read_private_key('skey') cert = make_certificate('ssh-rsa-cert-v01@openssh.com', CERT_TYPE_USER, skey, ckey, ['ckey'], options={'source-address': String('0.0.0.0')}) with self.assertRaises(asyncssh.DisconnectError): yield from self.connect(username='ckey', client_keys=[(skey, cert)]) @asynctest def test_untrusted_ca(self): """Test untrusted CA""" skey = asyncssh.read_private_key('skey') cert = make_certificate('ssh-rsa-cert-v01@openssh.com', CERT_TYPE_USER, skey, skey, ['skey']) with self.assertRaises(ValueError): yield from self.connect(username='ckey', client_keys=[('ckey', cert)]) @asynctest def test_callback(self): """Test connecting with public key authentication using callback""" with (yield from self._connect_publickey(['ckey'], test_async=True)) as conn: pass yield from conn.wait_closed() @asynctest def test_callback_sshkeypair(self): """Test client key passed in as an SSHKeyPair by callback""" if not self.agent_available(): # pragma: no cover self.skipTest('ssh-agent not available') agent = yield from asyncssh.connect_agent() keylist = yield from agent.get_keys() with (yield from self._connect_publickey(keylist)) as conn: pass yield from conn.wait_closed() agent.close() @asynctest def test_callback_untrusted_client_key(self): """Test failure connecting with public key authentication callback""" with self.assertRaises(asyncssh.DisconnectError): yield from self._connect_publickey(['skey']) @asynctest def test_unknown_auth(self): """Test server returning an unknown auth method before public key""" with patch('asyncssh.connection.SSHClientConnection', _UnknownAuthClientConnection): with (yield 
from self.connect(username='ckey', client_keys='ckey', agent_path=None)) as conn: pass yield from conn.wait_closed() class _TestPublicKeyAsyncServerAuth(_TestPublicKeyAuth): """Unit tests for public key authentication with async server callbacks""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server which supports async public key auth""" def server_factory(): """Return an SSH server which calls set_authorized_keys""" return _AsyncPublicKeyServer(client_keys=['ckey.pub', 'ckey_ecdsa.pub']) return (yield from cls.create_server(server_factory)) class _TestLimitedSignatureAlgs(ServerTestCase): """Unit tests for limited public key signature algorithms""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server which supports public key authentication""" return (yield from cls.create_server( _PublicKeyServer, authorized_client_keys='authorized_keys', signature_algs=['ssh-rsa', 'rsa-sha2-512'])) @asynctest def test_mismatched_signature_algs(self): """Test mismatched signature algorithms""" with self.assertRaises(asyncssh.DisconnectError): yield from self.connect(username='ckey', client_keys='ckey', signature_algs=['rsa-sha2-256']) @asynctest def test_signature_alg_fallback(self): """Test fall back to default signature algorithm""" with (yield from self.connect(username='ckey', client_keys='ckey', signature_algs=['rsa-sha2-256', 'ssh-rsa'])) as conn: pass yield from conn.wait_closed() class _TestSetAuthorizedKeys(ServerTestCase): """Unit tests for public key authentication with set_authorized_keys""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server which supports public key authentication""" def server_factory(): """Return an SSH server which calls set_authorized_keys""" return _PublicKeyServer(authorized_keys='authorized_keys') return (yield from cls.create_server(server_factory)) @asynctest def test_set_authorized_keys(self): """Test set_authorized_keys method on server""" with (yield from self.connect(username='ckey', client_keys='ckey')) as conn: pass yield from conn.wait_closed() @asynctest def test_cert_principals(self): """Test certificate principals check""" ckey = asyncssh.read_private_key('ckey') cert = make_certificate('ssh-rsa-cert-v01@openssh.com', CERT_TYPE_USER, ckey, ckey, ['ckey']) with (yield from self.connect(username='ckey', client_keys=[(ckey, cert)])) as conn: pass yield from conn.wait_closed() class _TestPreloadedAuthorizedKeys(ServerTestCase): """Unit tests for authentication with pre-loaded authorized keys""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server which supports public key authentication""" def server_factory(): """Return an SSH server which calls set_authorized_keys""" authorized_keys = asyncssh.read_authorized_keys('authorized_keys') return _PublicKeyServer(authorized_keys=authorized_keys) return (yield from cls.create_server(server_factory)) @asynctest def test_pre_loaded_authorized_keys(self): """Test set_authorized_keys with pre-loaded authorized keys""" with (yield from self.connect(username='ckey', client_keys='ckey')) as conn: pass yield from conn.wait_closed() @unittest.skipUnless(x509_available, 'X.509 not available') class _TestX509Auth(ServerTestCase): """Unit tests for X.509 certificate authentication""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server which supports public key authentication""" return (yield from cls.create_server( _PublicKeyServer, authorized_client_keys='authorized_keys_x509')) @asynctest def 
test_x509_self(self): """Test connecting with X.509 self-signed certificate""" with (yield from self.connect(username='ckey', client_keys=['ckey_x509_self'])) as conn: pass yield from conn.wait_closed() @asynctest def test_x509_chain(self): """Test connecting with X.509 certificate chain""" with (yield from self.connect(username='ckey', client_keys=['ckey_x509_chain'])) as conn: pass yield from conn.wait_closed() @asynctest def test_x509_incomplete_chain(self): """Test connecting with incomplete X.509 certificate chain""" with self.assertRaises(asyncssh.DisconnectError): yield from self.connect(username='ckey', client_keys=[('ckey_x509_chain', 'ckey_x509_partial.pem')]) @asynctest def test_x509_untrusted_cert(self): """Test connecting with untrusted X.509 certificate chain""" with self.assertRaises(asyncssh.DisconnectError): yield from self.connect(username='ckey', client_keys=['skey_x509_chain']) @unittest.skipUnless(x509_available, 'X.509 not available') class _TestX509AuthDisabled(ServerTestCase): """Unit tests for disabled X.509 certificate authentication""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server which doesn't support X.509 authentication""" return (yield from cls.create_server( _PublicKeyServer, x509_trusted_certs=None, authorized_client_keys='authorized_keys')) @asynctest def test_failed_x509_auth(self): """Test connect failure with X.509 certificate""" with self.assertRaises(asyncssh.DisconnectError): yield from self.connect(username='ckey', client_keys=['ckey_x509_self'], signature_algs=['x509v3-ssh-rsa']) @asynctest def test_non_x509(self): """Test connecting without an X.509 certificate""" with (yield from self.connect(username='ckey', client_keys=['ckey'])) as conn: pass yield from conn.wait_closed() @unittest.skipUnless(x509_available, 'X.509 not available') class _TestX509Subject(ServerTestCase): """Unit tests for X.509 certificate authentication by subject name""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server which supports public key authentication""" authorized_keys = asyncssh.import_authorized_keys( 'x509v3-ssh-rsa subject=OU=name\n') return (yield from cls.create_server( _PublicKeyServer, authorized_client_keys=authorized_keys, x509_trusted_certs=['ckey_x509_self.pub'])) @asynctest def test_x509_subject(self): """Test authenticating X.509 certificate by subject name""" with (yield from self.connect(username='ckey', client_keys=['ckey_x509_self'])) as conn: pass yield from conn.wait_closed() @unittest.skipUnless(x509_available, 'X.509 not available') class _TestX509Untrusted(ServerTestCase): """Unit tests for X.509 authentication with no trusted certificates""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server which supports public key authentication""" return (yield from cls.create_server( _PublicKeyServer, authorized_client_keys=None)) @asynctest def test_x509_untrusted(self): """Test untrusted X.509 self-signed certificate""" with self.assertRaises(asyncssh.DisconnectError): yield from self.connect(username='ckey', client_keys=['ckey_x509_self']) @unittest.skipUnless(x509_available, 'X.509 not available') class _TestX509Disabled(ServerTestCase): """Unit tests for X.509 authentication with server support disabled""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server with X.509 authentication disabled""" return (yield from cls.create_server(_PublicKeyServer, x509_purposes=None)) @asynctest def test_x509_disabled(self): """Test X.509 client 
certificate with server support disabled""" with self.assertRaises(asyncssh.DisconnectError): yield from self.connect(username='ckey', client_keys='skey_x509_self') class _TestPasswordAuth(ServerTestCase): """Unit tests for password authentication""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server which supports password authentication""" return (yield from cls.create_server(_PasswordServer)) @asyncio.coroutine def _connect_password(self, username, password, old_password='', new_password='', test_async=False): """Open a connection to test password authentication""" def client_factory(): """Return an SSHClient to use to do password change""" cls = _AsyncPasswordClient if test_async else _PasswordClient return cls(password, old_password, new_password) conn, _ = yield from self.create_connection(client_factory, username=username, client_keys=None) return conn @asynctest def test_password_auth(self): """Test connecting with password authentication""" with (yield from self.connect(username='pw', password='pw', client_keys=None)) as conn: pass yield from conn.wait_closed() @asynctest def test_password_auth_failure(self): """Test _failure connecting with password authentication""" with self.assertRaises(asyncssh.DisconnectError): yield from self.connect(username='pw', password='badpw', client_keys=None) @asynctest def test_password_auth_callback(self): """Test connecting with password authentication callback""" with (yield from self._connect_password('pw', 'pw', test_async=True)) as conn: pass yield from conn.wait_closed() @asynctest def test_password_auth_callback_failure(self): """Test failure connecting with password authentication callback""" with self.assertRaises(asyncssh.DisconnectError): yield from self._connect_password('pw', 'badpw') @asynctest def test_password_change(self): """Test password change""" with (yield from self._connect_password('pw', 'oldpw', 'oldpw', 'pw', test_async=True)) as conn: pass yield from conn.wait_closed() @asynctest def test_password_change_failure(self): """Test failure of password change""" with self.assertRaises(asyncssh.DisconnectError): yield from self._connect_password('pw', 'oldpw', 'badpw', 'pw') class _TestPasswordAsyncServerAuth(_TestPasswordAuth): """Unit tests for password authentication with async server callbacks""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server which supports async password authentication""" return (yield from cls.create_server(_AsyncPasswordServer)) class _TestKbdintAuth(ServerTestCase): """Unit tests for keyboard-interactive authentication""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server which supports keyboard-interactive auth""" return (yield from cls.create_server(_KbdintServer)) @asyncio.coroutine def _connect_kbdint(self, username, responses, test_async=False): """Open a connection to test keyboard-interactive auth""" def client_factory(): """Return an SSHClient to use to do keyboard-interactive auth""" cls = _AsyncKbdintClient if test_async else _KbdintClient return cls(responses) conn, _ = yield from self.create_connection(client_factory, username=username, client_keys=None) return conn @asynctest def test_kbdint_auth(self): """Test connecting with keyboard-interactive authentication""" with (yield from self.connect(username='kbdint', password='kbdint', client_keys=None)) as conn: pass yield from conn.wait_closed() @asynctest def test_kbdint_auth_failure(self): """Test failure connecting with keyboard-interactive 
authentication""" with self.assertRaises(asyncssh.DisconnectError): yield from self.connect(username='kbdint', password='badpw', client_keys=None) @asynctest def test_kbdint_auth_callback(self): """Test keyboard-interactive auth callback""" with (yield from self._connect_kbdint('kbdint', ['kbdint'], test_async=True)) as conn: pass yield from conn.wait_closed() @asynctest def test_kbdint_auth_callback_faliure(self): """Test failure connection with keyboard-interactive auth callback""" with self.assertRaises(asyncssh.DisconnectError): yield from self._connect_kbdint('kbdint', ['badpw']) class _TestKbdintAsyncServerAuth(_TestKbdintAuth): """Unit tests for keyboard-interactive auth with async server callbacks""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server which supports async kbd-int auth""" return (yield from cls.create_server(_AsyncKbdintServer)) class _TestKbdintPasswordServerAuth(ServerTestCase): """Unit tests for keyboard-interactive auth with server password auth""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server which supports server password auth""" return (yield from cls.create_server(_PasswordServer)) @asyncio.coroutine def _connect_kbdint(self, username, responses): """Open a connection to test keyboard-interactive auth""" def client_factory(): """Return an SSHClient to use to do keyboard-interactive auth""" return _KbdintClient(responses) conn, _ = yield from self.create_connection(client_factory, username=username, client_keys=None) return conn @asynctest def test_kbdint_password_auth(self): """Test keyboard-interactive server password authentication""" with (yield from self._connect_kbdint('pw', ['pw'])) as conn: pass yield from conn.wait_closed() @asynctest def test_kbdint_password_auth_multiple_responses(self): """Test multiple responses to server password authentication""" with self.assertRaises(asyncssh.DisconnectError): yield from self._connect_kbdint('pw', ['xxx', 'yyy']) @asynctest def test_kbdint_password_change(self): """Test keyboard-interactive server password change""" with self.assertRaises(asyncssh.DisconnectError): yield from self._connect_kbdint('pw', ['oldpw']) class _TestLoginTimeoutExceeded(ServerTestCase): """Unit test for login timeout""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server with a 1 second login timeout""" return (yield from cls.create_server( _PublicKeyServer, authorized_client_keys='authorized_keys', login_timeout=1)) @asynctest def test_login_timeout_exceeded(self): """Test login timeout exceeded""" def client_factory(): """Return an SSHClient that delays before providing a key""" return _PublicKeyClient(['ckey'], 2) with self.assertRaises(asyncssh.DisconnectError): yield from self.create_connection(client_factory, username='ckey', client_keys=None) class _TestLoginTimeoutDisabled(ServerTestCase): """Unit test for disabled login timeout""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server with no login timeout""" return (yield from cls.create_server( _PublicKeyServer, authorized_client_keys='authorized_keys', login_timeout=None)) @asynctest def test_login_timeout_disabled(self): """Test with login timeout disabled""" with (yield from self.connect(username='ckey', client_keys='ckey')) as conn: pass yield from conn.wait_closed() asyncssh-1.11.1/tests/test_editor.py000066400000000000000000000212211320320510200174510ustar00rootroot00000000000000# Copyright (c) 2016 by Ron Frederick . # All rights reserved. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for AsyncSSH line editor""" import asyncio import asyncssh from .server import ServerTestCase from .util import asynctest def _handle_session(stdin, stdout, stderr): """Accept a single line of input and echo it with a prefix""" # pylint: disable=unused-argument break_count = 0 prefix = '>>>' if stdin.channel.get_encoding() else b'>>>' data = '' if stdin.channel.get_encoding() else b'' while not stdin.at_eof(): try: data += yield from stdin.readline() except asyncssh.BreakReceived: break_count += 1 stdout.write('B') if break_count == 1: # Set twice to get coverage of when echo isn't changing stdin.channel.set_echo(False) stdin.channel.set_echo(False) elif break_count == 2: stdin.channel.set_echo(True) elif break_count == 3: stdin.channel.set_line_mode(False) else: data = 'BREAK' except asyncssh.TerminalSizeChanged: continue stdout.write(prefix + data) stdout.close() class _CheckEditor(ServerTestCase): """Utility functions for AsyncSSH line editor unit tests""" @asyncio.coroutine def check_input(self, input_data, expected_result, term_type='ansi', set_width=False): """Feed input data and compare echoed back result""" with (yield from self.connect()) as conn: process = yield from conn.create_process(term_type=term_type) process.stdin.write(input_data) if set_width: process.change_terminal_size(132, 24) process.stdin.write_eof() output_data = (yield from process.wait()).stdout idx = output_data.rfind('>>>') self.assertNotEqual(idx, -1) output_data = output_data[idx+3:] self.assertEqual(output_data, expected_result) class _TestEditor(_CheckEditor): """Unit tests for AsyncSSH line editor""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server for the tests to use""" return (yield from cls.create_server(session_factory=_handle_session)) @asynctest def test_editor(self): """Test line editing""" tests = ( ('Simple line', 'abc\n', 'abc\r\n'), ('EOF', '\x04', ''), ('Erase left', 'abcd\x08\n', 'abc\r\n'), ('Erase left', 'abcd\x08\n', 'abc\r\n'), ('Erase left at beginning', '\x08abc\n', 'abc\r\n'), ('Erase right', 'abcd\x02\x04\n', 'abc\r\n'), ('Erase right at end', 'abc\x04\n', 'abc\r\n'), ('Erase line', 'abcdef\x15abc\n', 'abc\r\n'), ('Erase to end', 'abcdef\x02\x02\x02\x0b\n', 'abc\r\n'), ('History previous', 'abc\n\x10\n', 'abc\r\nabc\r\n'), ('History previous at top', '\x10abc\n', 'abc\r\n'), ('History next', 'a\nb\n\x10\x10\x0e\x08c\n', 'a\r\nb\r\nc\r\n'), ('History next to bottom', 'abc\n\x10\x0e\n', 'abc\r\n\r\n'), ('History next at bottom', '\x0eabc\n', 'abc\r\n'), ('Move left', 'abc\x02\n', 'abc\r\n'), ('Move left at beginning', '\x02abc\n', 'abc\r\n'), ('Move left arrow', 'abc\x1b[D\n', 'abc\r\n'), ('Move right', 'abc\x02\x06\n', 'abc\r\n'), ('Move right at end', 'abc\x06\n', 'abc\r\n'), ('Move to start', 'abc\x01\n', 'abc\r\n'), ('Move to end', 'abc\x02\x05\n', 'abc\r\n'), ('Redraw', 'abc\x12\n', 'abc\r\n'), ('Insert erased', 'abc\x15\x19\x19\n', 'abcabc\r\n'), ('Send break', '\x03\x03\x03\x03', 'BREAK'), ('Long line', 100*'*' + '\x02\x01\x05\n', 100*'*' + '\r\n'), ('Wide char wrap', 79*'*' + '\uff10\n', 79*'*' + '\uff10\r\n'), ('Unknown key', '\x07abc\n', 'abc\r\n') ) for testname, input_data, expected_result in tests: with self.subTest(testname): 
yield from self.check_input(input_data, expected_result) @asynctest def test_non_wrap(self): """Test line editing in non-wrap mode""" tests = ( ('Simple line', 'abc\n', 'abc\r\n'), ('Long line', 100*'*' + '\x02\x01\x05\n', 100*'*' + '\r\n'), ('Long line 2', 101*'*' + 30*'\x02' + '\x08\n', 100*'*' + '\r\n'), ('Redraw', 'abc\x12\n', 'abc\r\n') ) for testname, input_data, expected_result in tests: with self.subTest(testname): yield from self.check_input(input_data, expected_result, term_type='dumb') @asynctest def test_no_terminal(self): """Test that editor is disabled when no pseudo-terminal is requested""" yield from self.check_input('abc\n', 'abc\n', term_type=None) @asynctest def test_change_width(self): """Test changing the terminal width""" yield from self.check_input('abc\n', 'abc\r\n', set_width=True) @asynctest def test_change_width_non_wrap(self): """Test changing the terminal width when not wrapping""" yield from self.check_input('abc\n', 'abc\r\n', term_type='dumb', set_width=True) @asynctest def test_editor_echo_off(self): """Test editor with echo disabled""" with (yield from self.connect()) as conn: process = yield from conn.create_process(term_type='ansi') process.stdin.write('\x03') yield from process.stdout.readexactly(1) process.stdin.write('abcd\x08\n') process.stdin.write_eof() output_data = (yield from process.wait()).stdout self.assertEqual(output_data, '\r\n>>>abc\r\n') @asynctest def test_editor_echo_on(self): """Test editor with echo re-enabled""" with (yield from self.connect()) as conn: process = yield from conn.create_process(term_type='ansi') process.stdin.write('\x03') yield from process.stdout.readexactly(1) process.stdin.write('abc') process.stdin.write('\x03') yield from process.stdout.readexactly(1) process.stdin.write('\n') process.stdin.write_eof() output_data = (yield from process.wait()).stdout self.assertEqual(output_data, 'abc\r\n>>>abc\r\n') @asynctest def test_editor_line_mode_off(self): """Test editor with line mode disabled""" with (yield from self.connect()) as conn: process = yield from conn.create_process(term_type='ansi') process.stdin.write('\x03\x03') yield from process.stdout.readexactly(2) process.stdin.write('abc\x03') yield from process.stdout.readexactly(15) process.stdin.write('\n') process.stdin.write_eof() output_data = (yield from process.wait()).stdout self.assertEqual(output_data, 'abc\x1b[3D \x1b[3D>>>abc\r\n') class _TestEditorDisabled(_CheckEditor): """Unit tests for AsyncSSH line editor being disabled""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server for the tests to use""" return (yield from cls.create_server(session_factory=_handle_session, line_editor=False)) @asynctest def test_editor_disabled(self): """Test that editor is disabled""" yield from self.check_input('abc\n', 'abc\n') class _TestEditorEncodingNone(_CheckEditor): """Unit tests for AsyncSSH line editor disabled due to encoding None""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server for the tests to use""" return (yield from cls.create_server(session_factory=_handle_session, session_encoding=None)) @asynctest def test_editor_disabled_encoding_none(self): """Test that editor is disabled when encoding is None""" yield from self.check_input('abc\n', 'abc\n') @asynctest def test_change_width(self): """Test changing the terminal width""" yield from self.check_input('abc\n', 'abc\n', set_width=True) 
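# For reference, a minimal standalone client that drives a line-editor
# session the same way these tests do might look like the sketch below.
# It assumes a server built around a session factory such as
# _handle_session() above; the host, port, username and password are
# placeholders rather than values used by this test suite:
#
#     import asyncio
#     import asyncssh
#
#     @asyncio.coroutine
#     def run_editor_client():
#         conn = yield from asyncssh.connect('localhost', 8022,
#                                            username='user', password='pw',
#                                            known_hosts=None)
#         with conn:
#             process = yield from conn.create_process(term_type='ansi')
#             process.stdin.write('abcd\x08\n')    # '\x08' erases the 'd'
#             process.stdin.write_eof()
#             result = yield from process.wait()
#             # with _handle_session(), result.stdout ends in '>>>abc\r\n'
#
#     asyncio.get_event_loop().run_until_complete(run_editor_client())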
asyncssh-1.11.1/tests/test_forward.py000066400000000000000000000730471320320510200176440ustar00rootroot00000000000000# Copyright (c) 2016 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for AsyncSSH forwarding API""" import asyncio import os import socket import sys import unittest from unittest.mock import patch import asyncssh from asyncssh.packet import String, UInt32 from asyncssh.public_key import CERT_TYPE_USER from .server import Server, ServerTestCase from .util import asynctest, echo, make_certificate def _echo_non_async(stdin, stdout, stderr=None): """Non-async version of echo callback""" conn = stdin.get_extra_info('connection') conn.create_task(echo(stdin, stdout, stderr)) def _listener(orig_host, orig_port): """Handle a forwarded TCP/IP connection""" # pylint: disable=unused-argument return echo def _listener_non_async(orig_host, orig_port): """Non-async version of handler for a forwarded TCP/IP connection""" # pylint: disable=unused-argument return _echo_non_async def _unix_listener(): """Handle a forwarded UNIX domain connection""" # pylint: disable=unused-argument return echo def _unix_listener_non_async(): """Non-async version of handler for a forwarded UNIX domain connection""" # pylint: disable=unused-argument return _echo_non_async @asyncio.coroutine def _pause(reader, writer): """Sleep to allow buffered data to build up and trigger a pause""" yield from asyncio.sleep(0.1) yield from reader.read() writer.close() class _ClientConn(asyncssh.SSHClientConnection): """Patched SSH client connection for unit testing""" @asyncio.coroutine def make_global_request(self, request, *args): """Send a global request and wait for the response""" return self._make_global_request(request, *args) class _EchoPortListener(asyncssh.SSHListener): """A TCP listener which opens a connection that echoes data""" def __init__(self, conn): self._conn = conn conn.create_task(self._open_connection()) @asyncio.coroutine def _open_connection(self): """Open a forwarded connection that echoes data""" yield from asyncio.sleep(0.1) reader, writer = yield from self._conn.open_connection('open', 65535) yield from echo(reader, writer) def close(self): """Stop listening for new connections""" pass @asyncio.coroutine def wait_closed(self): """Wait for the listener to close""" pass # pragma: no cover class _EchoPathListener(asyncssh.SSHListener): """A UNIX domain listener which opens a connection that echoes data""" def __init__(self, conn): self._conn = conn conn.create_task(self._open_connection()) @asyncio.coroutine def _open_connection(self): """Open a forwarded connection that echoes data""" yield from asyncio.sleep(0.1) reader, writer = yield from self._conn.open_unix_connection('open') yield from echo(reader, writer) def close(self): """Stop listening for new connections""" pass @asyncio.coroutine def wait_closed(self): """Wait for the listener to close""" pass # pragma: no cover class _TCPConnectionServer(Server): """Server for testing direct and forwarded TCP connections""" def connection_requested(self, dest_host, dest_port, orig_host, orig_port): """Handle a request to create a new connection""" if dest_port in {0, self._conn.get_extra_info('sockname')[1]}: return True elif dest_port == 7: 
return (self._conn.create_tcp_channel(), echo) elif dest_port == 8: return _pause elif dest_port == 9: self._conn.close() return (self._conn.create_tcp_channel(), echo) else: return False def server_requested(self, listen_host, listen_port): """Handle a request to create a new socket listener""" if listen_host == 'open': return _EchoPortListener(self._conn) else: return listen_host != 'fail' class _UNIXConnectionServer(Server): """Server for testing direct and forwarded UNIX domain connections""" def unix_connection_requested(self, dest_path): """Handle a request to create a new UNIX domain connection""" if dest_path == '': return True elif dest_path == '/echo': return (self._conn.create_unix_channel(), echo) else: return False def unix_server_requested(self, listen_path): """Handle a request to create a new UNIX domain listener""" if listen_path == 'open': return _EchoPathListener(self._conn) else: return listen_path != 'fail' class _CheckForwarding(ServerTestCase): """Utility functions for AsyncSSH forwarding unit tests""" @asyncio.coroutine def _check_echo_line(self, reader, writer, delay=False, encoded=False): """Check if an input line is properly echoed back""" if delay: yield from asyncio.sleep(delay) line = str(id(self)) + '\n' if not encoded: line = line.encode('utf-8') writer.write(line) yield from writer.drain() result = yield from reader.readline() writer.close() self.assertEqual(line, result) @asyncio.coroutine def _check_echo_block(self, reader, writer): """Check if a block of data is properly echoed back""" data = 4 * [1025*1024*b'\0'] writer.writelines(data) yield from writer.drain() writer.write_eof() result = yield from reader.read() yield from reader.channel.wait_closed() writer.close() self.assertEqual(b''.join(data), result) class _TestTCPForwarding(_CheckForwarding): """Unit tests for AsyncSSH TCP connection forwarding""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server which supports TCP connection forwarding""" return (yield from cls.create_server( _TCPConnectionServer, authorized_client_keys='authorized_keys')) @asyncio.coroutine def _check_connection(self, conn, dest_host='', dest_port=7, **kwargs): """Open a connection and test if a block of data is echoed back""" reader, writer = yield from conn.open_connection(dest_host, dest_port, **kwargs) yield from self._check_echo_block(reader, writer) @asyncio.coroutine def _check_local_connection(self, listen_port, delay=None): """Open a local connection and test if an input line is echoed back""" reader, writer = yield from asyncio.open_connection(None, listen_port) yield from self._check_echo_line(reader, writer, delay=delay) @asynctest def test_ssh_create_tunnel(self): """Test creating a tunneled SSH connection""" with (yield from self.connect()) as conn: conn2, _ = yield from conn.create_ssh_connection( None, self._server_addr, self._server_port) with conn2: yield from self._check_connection(conn2) yield from conn2.wait_closed() yield from conn.wait_closed() @asynctest def test_ssh_connect_tunnel(self): """Test connecting a tunneled SSH connection""" with (yield from self.connect()) as conn: with (yield from conn.connect_ssh(self._server_addr, self._server_port)) as conn2: yield from self._check_connection(conn2) yield from conn2.wait_closed() yield from conn.wait_closed() @asynctest def test_connection(self): """Test opening a remote connection""" with (yield from self.connect()) as conn: yield from self._check_connection(conn) yield from conn.wait_closed() @asynctest def 
test_connection_failure(self): """Test failure in opening a remote connection""" with (yield from self.connect()) as conn: with self.assertRaises(asyncssh.ChannelOpenError): yield from conn.open_connection('', 0) yield from conn.wait_closed() @asynctest def test_connection_rejected(self): """Test rejection in opening a remote connection""" with (yield from self.connect()) as conn: with self.assertRaises(asyncssh.ChannelOpenError): yield from conn.open_connection('fail', 0) yield from conn.wait_closed() @asynctest def test_connection_not_permitted(self): """Test permission denied in opening a remote connection""" ckey = asyncssh.read_private_key('ckey') cert = make_certificate('ssh-rsa-cert-v01@openssh.com', CERT_TYPE_USER, ckey, ckey, ['ckey'], extensions={'no-port-forwarding': ''}) with (yield from self.connect(username='ckey', client_keys=[(ckey, cert)])) as conn: with self.assertRaises(asyncssh.ChannelOpenError): yield from conn.open_connection('', 7) yield from conn.wait_closed() @asynctest def test_connection_not_permitted_open(self): """Test open permission denied in opening a remote connection""" with (yield from self.connect(username='ckey', client_keys=['ckey'])) as conn: with self.assertRaises(asyncssh.ChannelOpenError): yield from conn.open_connection('fail', 7) yield from conn.wait_closed() @asynctest def test_connection_invalid_unicode(self): """Test opening a connection with invalid Unicode in host""" with (yield from self.connect()) as conn: with self.assertRaises(asyncssh.ChannelOpenError): yield from conn.open_connection(b'\xff', 0) yield from conn.wait_closed() @asynctest def test_server(self): """Test creating a remote listener""" with (yield from self.connect()) as conn: listener = yield from conn.start_server(_listener, '', 0) yield from self._check_local_connection(listener.get_port()) listener.close() listener.close() yield from listener.wait_closed() listener.close() yield from conn.wait_closed() @asynctest def test_server_open(self): """Test creating a remote listener which uses open_connection""" def new_connection(reader, writer): """Handle a forwarded TCP/IP connection""" waiter.set_result((reader, writer)) def handler_factory(orig_host, orig_port): """Handle all connections using new_connection""" # pylint: disable=unused-argument return new_connection with (yield from self.connect()) as conn: waiter = asyncio.Future(loop=self.loop) yield from conn.start_server(handler_factory, 'open', 0) reader, writer = yield from waiter yield from self._check_echo_line(reader, writer) # Clean up the listener during connection close yield from conn.wait_closed() @asynctest def test_server_non_async(self): """Test creating a remote listener using non-async handler""" with (yield from self.connect()) as conn: listener = yield from conn.start_server(_listener_non_async, '', 0) yield from self._check_local_connection(listener.get_port()) listener.close() yield from listener.wait_closed() yield from conn.wait_closed() @asynctest def test_server_failure(self): """Test failure in creating a remote listener""" with (yield from self.connect()) as conn: listener = yield from conn.start_server(_listener, 'fail', 0) self.assertIsNone(listener) yield from conn.wait_closed() @asynctest def test_forward_local_port(self): """Test forwarding of a local port""" with (yield from self.connect()) as conn: listener = yield from conn.forward_local_port('', 0, '', 7) yield from self._check_local_connection(listener.get_port(), delay=0.1) listener.close() yield from listener.wait_closed() yield from 
conn.wait_closed() @asynctest def test_forward_local_port_pause(self): """Test pause during forwarding of a local port""" with (yield from self.connect()) as conn: listener = yield from conn.forward_local_port('', 0, '', 8) listen_port = listener.get_port() reader, writer = yield from asyncio.open_connection(None, listen_port) writer.write(4*1024*1024*b'\0') writer.write_eof() yield from reader.read() writer.close() listener.close() yield from listener.wait_closed() yield from conn.wait_closed() @asynctest def test_forward_local_port_failure(self): """Test failure in forwarding a local port""" with (yield from self.connect()) as conn: listener = yield from conn.forward_local_port('', 0, '', 65535) listen_port = listener.get_port() reader, writer = yield from asyncio.open_connection(None, listen_port) self.assertEqual((yield from reader.read()), b'') writer.close() listener.close() yield from listener.wait_closed() yield from conn.wait_closed() @unittest.skipIf(sys.platform == 'win32', 'skip dual-stack tests on Windows') @asynctest def test_forward_bind_error_ipv4(self): """Test error binding a local forwarding port""" with (yield from self.connect()) as conn: listener = yield from conn.forward_local_port('0.0.0.0', 0, '', 7) with self.assertRaises(OSError): yield from conn.forward_local_port(None, listener.get_port(), '', 7) listener.close() yield from listener.wait_closed() yield from conn.wait_closed() @unittest.skipIf(sys.platform == 'win32', 'skip dual-stack tests on Windows') @asynctest def test_forward_bind_error_ipv6(self): """Test error binding a local forwarding port""" with (yield from self.connect()) as conn: listener = yield from conn.forward_local_port('::', 0, '', 7) with self.assertRaises(OSError): yield from conn.forward_local_port(None, listener.get_port(), '', 7) listener.close() yield from listener.wait_closed() yield from conn.wait_closed() @asynctest def test_forward_connect_error(self): """Test error connecting a local forwarding port""" with (yield from self.connect()) as conn: listener = yield from conn.forward_local_port('', 0, '', 0) listen_port = listener.get_port() reader, writer = yield from asyncio.open_connection(None, listen_port) self.assertEqual((yield from reader.read()), b'') writer.close() listener.close() yield from listener.wait_closed() yield from conn.wait_closed() @asynctest def test_forward_immediate_eof(self): """Test getting EOF before forwarded connection is fully open""" with (yield from self.connect()) as conn: listener = yield from conn.forward_local_port('', 0, '', 7) listen_port = listener.get_port() _, writer = yield from asyncio.open_connection(None, listen_port) writer.close() yield from asyncio.sleep(0.1) listener.close() yield from listener.wait_closed() yield from conn.wait_closed() @asynctest def test_forward_remote_port(self): """Test forwarding of a remote port""" server = yield from asyncio.start_server(echo, None, 0, family=socket.AF_INET) server_port = server.sockets[0].getsockname()[1] with (yield from self.connect()) as conn: listener = yield from conn.forward_remote_port('', 0, '', server_port) yield from self._check_local_connection(listener.get_port()) listener.close() yield from listener.wait_closed() yield from conn.wait_closed() server.close() yield from server.wait_closed() @asynctest def test_forward_remote_specific_port(self): """Test forwarding of a specific remote port""" server = yield from asyncio.start_server(echo, None, 0, family=socket.AF_INET) server_port = server.sockets[0].getsockname()[1] sock = 
socket.socket() sock.bind(('', 0)) remote_port = sock.getsockname()[1] sock.close() with (yield from self.connect()) as conn: listener = yield from conn.forward_remote_port('', remote_port, '', server_port) yield from self._check_local_connection(listener.get_port()) listener.close() yield from listener.wait_closed() yield from conn.wait_closed() server.close() yield from server.wait_closed() @asynctest def test_forward_remote_port_failure(self): """Test failure of forwarding a remote port""" with (yield from self.connect()) as conn: listener = yield from conn.forward_remote_port('', 65536, '', 0) self.assertIsNone(listener) yield from conn.wait_closed() @asynctest def test_forward_remote_port_not_permitted(self): """Test permission denied in forwarding of a remote port""" ckey = asyncssh.read_private_key('ckey') cert = make_certificate('ssh-rsa-cert-v01@openssh.com', CERT_TYPE_USER, ckey, ckey, ['ckey'], extensions={'no-port-forwarding': ''}) with (yield from self.connect(username='ckey', client_keys=[(ckey, cert)])) as conn: listener = yield from conn.forward_remote_port('', 0, '', 0) self.assertIsNone(listener) yield from conn.wait_closed() @asynctest def test_forward_remote_port_invalid_unicode(self): """Test TCP/IP forwarding with invalid Unicode in host""" with (yield from self.connect()) as conn: listener = yield from conn.forward_remote_port(b'\xff', 0, '', 0) self.assertIsNone(listener) yield from conn.wait_closed() @asynctest def test_cancel_forward_remote_port_invalid_unicode(self): """Test canceling TCP/IP forwarding with invalid Unicode in host""" with patch('asyncssh.connection.SSHClientConnection', _ClientConn): with (yield from self.connect()) as conn: pkttype, _ = yield from conn.make_global_request( b'cancel-tcpip-forward', String(b'\xff'), UInt32(0)) self.assertEqual(pkttype, asyncssh.MSG_REQUEST_FAILURE) yield from conn.wait_closed() @asynctest def test_add_channel_after_close(self): """Test opening a connection after a close""" with (yield from self.connect()) as conn: with self.assertRaises(asyncssh.ChannelOpenError): yield from conn.open_connection('', 9) yield from conn.wait_closed() @asynctest def test_multiple_global_requests(self): """Test sending multiple global requests in parallel""" with (yield from self.connect()) as conn: listeners = yield from asyncio.gather( conn.forward_remote_port('', 0, '', 7), conn.forward_remote_port('', 0, '', 7)) for listener in listeners: listener.close() yield from listener.wait_closed() yield from conn.wait_closed() @unittest.skipIf(sys.platform == 'win32', 'skip UNIX domain socket tests on Windows') class _TestUNIXForwarding(_CheckForwarding): """Unit tests for AsyncSSH UNIX connection forwarding""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server which supports UNIX connection forwarding""" return (yield from cls.create_server( _UNIXConnectionServer, authorized_client_keys='authorized_keys')) @asyncio.coroutine def _check_unix_connection(self, conn, dest_path='/echo', **kwargs): """Open a UNIX connection and test if an input line is echoed back""" reader, writer = yield from conn.open_unix_connection(dest_path, encoding='utf-8', **kwargs) yield from self._check_echo_line(reader, writer, encoded=True) @asyncio.coroutine def _check_local_unix_connection(self, listen_path): """Open a local connection and test if an input line is echoed back""" # pylint doesn't think open_unix_connection exists # pylint: disable=no-member reader, writer = yield from asyncio.open_unix_connection(listen_path) # pylint: 
enable=no-member yield from self._check_echo_line(reader, writer) @asynctest def test_unix_connection(self): """Test opening a remote UNIX connection""" with (yield from self.connect()) as conn: yield from self._check_unix_connection(conn) yield from conn.wait_closed() @asynctest def test_unix_connection_failure(self): """Test failure in opening a remote UNIX connection""" with (yield from self.connect()) as conn: with self.assertRaises(asyncssh.ChannelOpenError): yield from conn.open_unix_connection('') yield from conn.wait_closed() @asynctest def test_unix_connection_rejected(self): """Test rejection in opening a remote UNIX connection""" with (yield from self.connect()) as conn: with self.assertRaises(asyncssh.ChannelOpenError): yield from conn.open_unix_connection('/fail') yield from conn.wait_closed() @asynctest def test_unix_connection_not_permitted(self): """Test permission denied in opening a remote UNIX connection""" ckey = asyncssh.read_private_key('ckey') cert = make_certificate('ssh-rsa-cert-v01@openssh.com', CERT_TYPE_USER, ckey, ckey, ['ckey'], extensions={'no-port-forwarding': ''}) with (yield from self.connect(username='ckey', client_keys=[(ckey, cert)])) as conn: with self.assertRaises(asyncssh.ChannelOpenError): yield from conn.open_unix_connection('/echo') yield from conn.wait_closed() @asynctest def test_unix_connection_invalid_unicode(self): """Test opening a UNIX connection with invalid Unicode in path""" with (yield from self.connect()) as conn: with self.assertRaises(asyncssh.ChannelOpenError): yield from conn.open_unix_connection(b'\xff') yield from conn.wait_closed() @asynctest def test_unix_server(self): """Test creating a remote UNIX listener""" path = os.path.abspath('echo') with (yield from self.connect()) as conn: listener = yield from conn.start_unix_server(_unix_listener, path) yield from self._check_local_unix_connection('echo') listener.close() listener.close() yield from listener.wait_closed() listener.close() yield from conn.wait_closed() os.remove('echo') @asynctest def test_unix_server_open(self): """Test creating a UNIX listener which uses open_unix_connection""" def new_connection(reader, writer): """Handle a forwarded UNIX domain connection""" waiter.set_result((reader, writer)) def handler_factory(): """Handle all connections using new_connection""" # pylint: disable=unused-argument return new_connection with (yield from self.connect()) as conn: waiter = asyncio.Future(loop=self.loop) listener = yield from conn.start_unix_server(handler_factory, 'open') reader, writer = yield from waiter yield from self._check_echo_line(reader, writer) listener.close() yield from listener.wait_closed() yield from conn.wait_closed() @asynctest def test_unix_server_non_async(self): """Test creating a remote UNIX listener using non-async handler""" path = os.path.abspath('echo') with (yield from self.connect()) as conn: listener = yield from conn.start_unix_server( _unix_listener_non_async, path) yield from self._check_local_unix_connection('echo') listener.close() yield from listener.wait_closed() yield from conn.wait_closed() os.remove('echo') @asynctest def test_unix_server_failure(self): """Test failure in creating a remote UNIX listener""" with (yield from self.connect()) as conn: listener = yield from conn.start_unix_server(_unix_listener, 'fail') self.assertIsNone(listener) yield from conn.wait_closed() @asynctest def test_forward_local_path(self): """Test forwarding of a local UNIX domain path""" with (yield from self.connect()) as conn: listener = yield from 
conn.forward_local_path('local', '/echo') yield from self._check_local_unix_connection('local') listener.close() yield from listener.wait_closed() yield from conn.wait_closed() os.remove('local') @asynctest def test_forward_remote_path(self): """Test forwarding of a remote UNIX domain path""" # pylint doesn't think start_unix_server exists # pylint: disable=no-member server = yield from asyncio.start_unix_server(echo, 'local') # pylint: enable=no-member path = os.path.abspath('echo') with (yield from self.connect()) as conn: listener = yield from conn.forward_remote_path(path, 'local') yield from self._check_local_unix_connection('echo') listener.close() yield from listener.wait_closed() yield from conn.wait_closed() server.close() yield from server.wait_closed() os.remove('echo') os.remove('local') @asynctest def test_forward_remote_path_failure(self): """Test failure of forwarding a remote UNIX domain path""" open('echo', 'w').close() path = os.path.abspath('echo') with (yield from self.connect()) as conn: listener = yield from conn.forward_remote_path(path, 'local') self.assertIsNone(listener) yield from conn.wait_closed() os.remove('echo') @asynctest def test_forward_remote_path_not_permitted(self): """Test permission denied in forwarding a remote UNIX domain path""" ckey = asyncssh.read_private_key('ckey') cert = make_certificate('ssh-rsa-cert-v01@openssh.com', CERT_TYPE_USER, ckey, ckey, ['ckey'], extensions={'no-port-forwarding': ''}) with (yield from self.connect(username='ckey', client_keys=[(ckey, cert)])) as conn: listener = yield from conn.forward_remote_path('', 'local') self.assertIsNone(listener) yield from conn.wait_closed() @asynctest def test_forward_remote_path_invalid_unicode(self): """Test forwarding a UNIX domain path with invalid Unicode in it""" with (yield from self.connect()) as conn: listener = yield from conn.forward_remote_path(b'\xff', 'local') self.assertIsNone(listener) yield from conn.wait_closed() @asynctest def test_cancel_forward_remote_path_invalid_unicode(self): """Test canceling UNIX forwarding with invalid Unicode in path""" with patch('asyncssh.connection.SSHClientConnection', _ClientConn): with (yield from self.connect()) as conn: pkttype, _ = yield from conn.make_global_request( b'cancel-streamlocal-forward@openssh.com', String(b'\xff')) self.assertEqual(pkttype, asyncssh.MSG_REQUEST_FAILURE) yield from conn.wait_closed() asyncssh-1.11.1/tests/test_kex.py000066400000000000000000000422621320320510200167620ustar00rootroot00000000000000# Copyright (c) 2015-2017 by Ron Frederick . # All rights reserved. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for key exchange""" import asyncio import unittest from hashlib import sha1 import asyncssh from asyncssh.dh import MSG_KEXDH_INIT, MSG_KEXDH_REPLY from asyncssh.dh import _KexDHGex, MSG_KEX_DH_GEX_REQUEST, MSG_KEX_DH_GEX_GROUP from asyncssh.dh import MSG_KEX_DH_GEX_INIT, MSG_KEX_DH_GEX_REPLY from asyncssh.dh import MSG_KEXGSS_INIT, MSG_KEXGSS_COMPLETE, MSG_KEXGSS_ERROR from asyncssh.ecdh import MSG_KEX_ECDH_INIT, MSG_KEX_ECDH_REPLY from asyncssh.gss import GSSClient, GSSServer from asyncssh.kex import register_kex_alg, get_kex_algs, get_kex from asyncssh.misc import DisconnectError from asyncssh.packet import SSHPacket, Boolean, Byte, MPInt, String from asyncssh.public_key import SSHLocalKeyPair, decode_ssh_public_key from .util import asynctest, gss_available, patch_gss from .util import AsyncTestCase, ConnectionStub # Short variable names are used here, matching names in the specs # pylint: disable=invalid-name class _KexConnectionStub(ConnectionStub): """Connection stub class to test key exchange""" def __init__(self, alg, gss, peer, server=False): super().__init__(peer, server) self._gss = gss self._key_waiter = asyncio.Future() self._kex = get_kex(self, alg) if self.is_client(): self._kex.start() def connection_lost(self, exc): """Handle the closing of a connection""" raise NotImplementedError def enable_gss_kex_auth(self): """Ignore request to enable GSS key exchange authentication""" pass def process_packet(self, data): """Process an incoming packet""" packet = SSHPacket(data) pkttype = packet.get_byte() self._kex.process_packet(pkttype, packet) def get_hash_prefix(self): """Return the bytes used in calculating unique connection hashes""" # pylint: disable=no-self-use return b'prefix' def send_newkeys(self, k, h): """Handle a request to send a new keys message""" self._key_waiter.set_result(self._kex.compute_key(k, h, b'A', h, 128)) def get_key(self): """Return generated key data""" return (yield from self._key_waiter) def get_gss_context(self): """Return the GSS context associated with this connection""" return self._gss def simulate_dh_init(self, e): """Simulate receiving a DH init packet""" self.process_packet(Byte(MSG_KEXDH_INIT) + MPInt(e)) def simulate_dh_reply(self, host_key_data, f, sig): """Simulate receiving a DH reply packet""" self.process_packet(b''.join((Byte(MSG_KEXDH_REPLY), String(host_key_data), MPInt(f), String(sig)))) def simulate_dh_gex_group(self, p, g): """Simulate receiving a DH GEX group packet""" self.process_packet(Byte(MSG_KEX_DH_GEX_GROUP) + MPInt(p) + MPInt(g)) def simulate_dh_gex_init(self, e): """Simulate receiving a DH GEX init packet""" self.process_packet(Byte(MSG_KEX_DH_GEX_INIT) + MPInt(e)) def simulate_dh_gex_reply(self, host_key_data, f, sig): """Simulate receiving a DH GEX reply packet""" self.process_packet(b''.join((Byte(MSG_KEX_DH_GEX_REPLY), String(host_key_data), MPInt(f), String(sig)))) def simulate_gss_complete(self, f, sig): """Simulate receiving a GSS complete packet""" self.process_packet(b''.join((Byte(MSG_KEXGSS_COMPLETE), MPInt(f), String(sig), Boolean(False)))) def simulate_ecdh_init(self, client_pub): """Simulate receiving an ECDH init packet""" self.process_packet(Byte(MSG_KEX_ECDH_INIT) + String(client_pub)) 
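    # simulate_ecdh_reply() below mirrors the SSH_MSG_KEX_ECDH_REPLY wire
    # format from RFC 5656: the server host key blob, the server's ephemeral
    # public key and the signature over the exchange hash, each encoded as
    # an SSH string (simulate_ecdh_init() above carries only the client's
    # ephemeral public key).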
def simulate_ecdh_reply(self, host_key_data, server_pub, sig): """Simulate receiving ab ECDH reply packet""" self.process_packet(b''.join((Byte(MSG_KEX_ECDH_REPLY), String(host_key_data), String(server_pub), String(sig)))) class _KexClientStub(_KexConnectionStub): """Stub class for client connection""" @classmethod def make_pair(cls, alg, gss_host=None): """Make a client and server connection pair to test key exchange""" client_conn = cls(alg, gss_host) return client_conn, client_conn.get_peer() def __init__(self, alg, gss_host): server_conn = _KexServerStub(alg, gss_host, self) try: if gss_host: gss = GSSClient(gss_host, 'delegate' in gss_host) else: gss = None super().__init__(alg, gss, server_conn) except DisconnectError: server_conn.close() raise def connection_lost(self, exc): """Handle the closing of a connection""" if exc and not self._key_waiter.done(): self._key_waiter.set_exception(exc) self.close() def validate_server_host_key(self, host_key_data): """Validate and return the server's host key""" # pylint: disable=no-self-use return decode_ssh_public_key(host_key_data) class _KexServerStub(_KexConnectionStub): """Stub class for server connection""" def __init__(self, alg, gss_host, peer): gss = GSSServer(gss_host) if gss_host else None super().__init__(alg, gss, peer, True) if gss_host and 'no_host_key' in gss_host: self._server_host_key = None else: priv_key = asyncssh.generate_private_key('ssh-rsa') self._server_host_key = SSHLocalKeyPair(priv_key) def connection_lost(self, exc): """Handle the closing of a connection""" if self._peer: self._peer.connection_lost(exc) self.close() def get_server_host_key(self): """Return the server host key""" return self._server_host_key @patch_gss class _TestKex(AsyncTestCase): """Unit tests for kex module""" @asyncio.coroutine def _check_kex(self, alg, gss_host=None): """Unit test key exchange""" client_conn, server_conn = _KexClientStub.make_pair(alg, gss_host) try: self.assertEqual((yield from client_conn.get_key()), (yield from server_conn.get_key())) finally: client_conn.close() server_conn.close() @asynctest def test_key_exchange_algs(self): """Unit test key exchange algorithms""" for alg in get_kex_algs(): with self.subTest(alg=alg): if alg.startswith(b'gss-'): if gss_available: # pragma: no branch yield from self._check_kex(alg + b'-mech', '1') else: yield from self._check_kex(alg) if gss_available: # pragma: no branch for steps in range(4): with self.subTest('GSS key exchange', steps=steps): yield from self._check_kex(b'gss-group1-sha1-mech', str(steps)) with self.subTest('GSS with credential delegation'): yield from self._check_kex(b'gss-group1-sha1-mech', '1,delegate') with self.subTest('GSS with no host key'): yield from self._check_kex(b'gss-group1-sha1-mech', '1,no_host_key') with self.subTest('GSS with full host principal'): yield from self._check_kex(b'gss-group1-sha1-mech', 'host/1@TEST') @asynctest def test_dh_gex_old(self): """Unit test old DH group exchange request""" register_kex_alg(b'dh-gex-sha1-1024', _KexDHGex, sha1, 1024) register_kex_alg(b'dh-gex-sha1-2048', _KexDHGex, sha1, 2048) for size in (b'1024', b'2048'): with self.subTest('Old DH group exchange', size=size): yield from self._check_kex(b'dh-gex-sha1-' + size) @asynctest def test_dh_gex(self): """Unit test old DH group exchange request""" register_kex_alg(b'dh-gex-sha1-1024-1536', _KexDHGex, sha1, 1024, 1536) register_kex_alg(b'dh-gex-sha1-1536-3072', _KexDHGex, sha1, 1536, 3072) register_kex_alg(b'dh-gex-sha1-2560-2560', _KexDHGex, sha1, 2560, 2560) 
register_kex_alg(b'dh-gex-sha1-2560-4096', _KexDHGex, sha1, 2560, 4096) register_kex_alg(b'dh-gex-sha1-9216-9216', _KexDHGex, sha1, 9216, 9216) for size in (b'1024-1536', b'1536-3072', b'2560-2560', b'2560-4096', b'9216-9216'): with self.subTest('Old DH group exchange', size=size): yield from self._check_kex(b'dh-gex-sha1-' + size) @asynctest def test_dh_errors(self): """Unit test error conditions in DH key exchange""" client_conn, server_conn = \ _KexClientStub.make_pair(b'diffie-hellman-group14-sha1') host_key = server_conn.get_server_host_key() with self.subTest('Init sent to client'): with self.assertRaises(DisconnectError): client_conn.process_packet(Byte(MSG_KEXDH_INIT)) with self.subTest('Reply sent to server'): with self.assertRaises(DisconnectError): server_conn.process_packet(Byte(MSG_KEXDH_REPLY)) with self.subTest('Invalid e value'): with self.assertRaises(DisconnectError): server_conn.simulate_dh_init(0) with self.subTest('Invalid f value'): with self.assertRaises(DisconnectError): client_conn.simulate_dh_reply(host_key.public_data, 0, b'') with self.subTest('Invalid signature'): with self.assertRaises(DisconnectError): client_conn.simulate_dh_reply(host_key.public_data, 1, b'') client_conn.close() server_conn.close() @asynctest def test_dh_gex_errors(self): """Unit test error conditions in DH group exchange""" client_conn, server_conn = \ _KexClientStub.make_pair(b'diffie-hellman-group-exchange-sha1') with self.subTest('Request sent to client'): with self.assertRaises(DisconnectError): client_conn.process_packet(Byte(MSG_KEX_DH_GEX_REQUEST)) with self.subTest('Group sent to server'): with self.assertRaises(DisconnectError): server_conn.simulate_dh_gex_group(1, 2) with self.subTest('Init sent to client'): with self.assertRaises(DisconnectError): client_conn.simulate_dh_gex_init(1) with self.subTest('Init sent before group'): with self.assertRaises(DisconnectError): server_conn.simulate_dh_gex_init(1) with self.subTest('Reply sent to server'): with self.assertRaises(DisconnectError): server_conn.simulate_dh_gex_reply(b'', 1, b'') with self.subTest('Reply sent before group'): with self.assertRaises(DisconnectError): client_conn.simulate_dh_gex_reply(b'', 1, b'') client_conn.close() server_conn.close() @unittest.skipUnless(gss_available, 'GSS not available') @asynctest def test_gss_errors(self): """Unit test error conditions in GSS key exchange""" client_conn, server_conn = \ _KexClientStub.make_pair(b'gss-group1-sha1-mech', '3') with self.subTest('Init sent to client'): with self.assertRaises(DisconnectError): client_conn.process_packet(Byte(MSG_KEXGSS_INIT)) with self.subTest('Complete sent to server'): with self.assertRaises(DisconnectError): server_conn.process_packet(Byte(MSG_KEXGSS_COMPLETE)) with self.subTest('Exchange failed to complete'): with self.assertRaises(DisconnectError): client_conn.simulate_gss_complete(1, b'succeed') with self.subTest('Error sent to server'): with self.assertRaises(DisconnectError): server_conn.process_packet(Byte(MSG_KEXGSS_ERROR)) client_conn.close() server_conn.close() with self.subTest('Signature verification failure'): with self.assertRaises(DisconnectError): yield from self._check_kex(b'gss-group1-sha1-mech', '0,fail') with self.subTest('Empty token in init'): with self.assertRaises(DisconnectError): yield from self._check_kex(b'gss-group1-sha1-mech', '0,empty_init') with self.subTest('Empty token in continue'): with self.assertRaises(DisconnectError): yield from self._check_kex(b'gss-group1-sha1-mech', '1,empty_continue') with 
self.subTest('Token after complete'): with self.assertRaises(DisconnectError): yield from self._check_kex(b'gss-group1-sha1-mech', '0,continue_token') for steps in range(2): with self.subTest('Token after complete', steps=steps): with self.assertRaises(DisconnectError): yield from self._check_kex(b'gss-group1-sha1-mech', str(steps) + ',extra_token') with self.subTest('Context not secure'): with self.assertRaises(DisconnectError): yield from self._check_kex(b'gss-group1-sha1-mech', '1,no_server_integrity') with self.subTest('GSS error'): with self.assertRaises(DisconnectError): yield from self._check_kex(b'gss-group1-sha1-mech', '1,step_error') with self.subTest('GSS error with error token'): with self.assertRaises(DisconnectError): yield from self._check_kex(b'gss-group1-sha1-mech', '1,step_error,errtok') @asynctest def test_ecdh_errors(self): """Unit test error conditions in ECDH key exchange""" try: from asyncssh.crypto import ECDH except ImportError: # pragma: no cover return client_conn, server_conn = \ _KexClientStub.make_pair(b'ecdh-sha2-nistp256') with self.subTest('Init sent to client'): with self.assertRaises(DisconnectError): client_conn.simulate_ecdh_init(b'') with self.subTest('Invalid client public key'): with self.assertRaises(DisconnectError): server_conn.simulate_ecdh_init(b'') with self.subTest('Reply sent to server'): with self.assertRaises(DisconnectError): server_conn.simulate_ecdh_reply(b'', b'', b'') with self.subTest('Invalid server host key'): with self.assertRaises(DisconnectError): client_conn.simulate_ecdh_reply(b'', b'', b'') with self.subTest('Invalid server public key'): with self.assertRaises(DisconnectError): host_key = server_conn.get_server_host_key() client_conn.simulate_ecdh_reply(host_key.public_data, b'', b'') with self.subTest('Invalid signature'): with self.assertRaises(DisconnectError): host_key = server_conn.get_server_host_key() server_pub = ECDH(b'nistp256').get_public() client_conn.simulate_ecdh_reply(host_key.public_data, server_pub, b'') client_conn.close() server_conn.close() @asynctest def test_curve25519dh_errors(self): """Unit test error conditions in Curve25519DH key exchange""" try: from asyncssh.crypto import Curve25519DH except ImportError: # pragma: no cover return client_conn, server_conn = \ _KexClientStub.make_pair(b'curve25519-sha256@libssh.org') with self.subTest('Invalid client public key'): with self.assertRaises(DisconnectError): server_conn.simulate_ecdh_init(b'') with self.subTest('Invalid server public key'): with self.assertRaises(DisconnectError): host_key = server_conn.get_server_host_key() client_conn.simulate_ecdh_reply(host_key.public_data, b'', b'') with self.subTest('Invalid signature'): with self.assertRaises(DisconnectError): host_key = server_conn.get_server_host_key() server_pub = Curve25519DH().get_public() client_conn.simulate_ecdh_reply(host_key.public_data, server_pub, b'') client_conn.close() server_conn.close() asyncssh-1.11.1/tests/test_known_hosts.py000066400000000000000000000245051320320510200205470ustar00rootroot00000000000000# Copyright (c) 2015-2017 by Ron Frederick . # All rights reserved. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for matching against known_hosts file""" import binascii import hashlib import hmac import os import asyncssh from .util import TempDirTestCase, x509_available if x509_available: # pragma: no branch from asyncssh.crypto import X509NamePattern def _hash(host): """Return a hashed version of a hostname in a known_hosts file""" salt = os.urandom(20) hosthash = hmac.new(salt, host.encode(), hashlib.sha1).digest() entry = b'|'.join((b'', b'1', binascii.b2a_base64(salt)[:-1], binascii.b2a_base64(hosthash)[:-1])) return entry.decode() class _TestKnownHosts(TempDirTestCase): """Unit tests for known_hosts module""" keylists = ([], [], [], [], [], [], []) imported_keylists = ([], [], [], [], [], [], []) @classmethod def setUpClass(cls): """Create public keys needed for test""" super().setUpClass() for keylist, imported_keylist in zip(cls.keylists[:3], cls.imported_keylists[:3]): for _ in range(3): key = asyncssh.generate_private_key('ssh-rsa') keylist.append(key.export_public_key().decode('ascii')) imported_keylist.append(key.convert_to_public()) if x509_available: # pragma: no branch for keylist, imported_keylist in zip(cls.keylists[3:5], cls.imported_keylists[3:5]): for _ in range(2): key = asyncssh.generate_private_key('ssh-rsa') cert = key.generate_x509_user_certificate(key, 'OU=user', 'OU=user') keylist.append( cert.export_certificate('openssh').decode('ascii')) imported_keylist.append(cert) for keylist, imported_keylist in zip(cls.keylists[5:], cls.imported_keylists[5:]): for name in ('OU=user', 'OU=revoked'): keylist.append('x509v3-ssh-rsa subject=' + name + '\n') imported_keylist.append(X509NamePattern(name)) def check_match(self, known_hosts, results=None, host='host', addr='1.2.3.4', port=22): """Check the result of calling match_known_hosts""" if results: results = tuple([kl[r] for r in result] for kl, result in zip(self.imported_keylists, results)) matches = asyncssh.match_known_hosts(known_hosts, host, addr, port) self.assertEqual(matches, results) def check_hosts(self, patlists, results=None, host='host', addr='1.2.3.4', port=22, from_file=False, from_bytes=False, as_callable=False, as_tuple=False): """Check a known_hosts file built from the specified patterns""" def call_match(host, addr, port): """Test passing callable as known_hosts""" return asyncssh.match_known_hosts(_known_hosts, host, addr, port) prefixes = ('', '@cert-authority ', '@revoked ', '', '@revoked ', '', '@revoked ') known_hosts = '# Comment line\n # Comment line with whitespace\n\n' for prefix, patlist, keys in zip(prefixes, patlists, self.keylists): for pattern, key in zip(patlist, keys): known_hosts += '%s%s %s' % (prefix, pattern, key) if from_file: with open('known_hosts', 'w') as f: f.write(known_hosts) known_hosts = 'known_hosts' elif from_bytes: known_hosts = known_hosts.encode() elif as_callable: _known_hosts = asyncssh.import_known_hosts(known_hosts) known_hosts = call_match elif as_tuple: known_hosts = asyncssh.import_known_hosts(known_hosts) known_hosts = asyncssh.match_known_hosts(known_hosts, host, addr, port) else: known_hosts = asyncssh.import_known_hosts(known_hosts) return self.check_match(known_hosts, results, host, addr, port) def test_match(self): """Test known host matching""" 
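        # Each entry below is (test name, known_hosts pattern lists, expected
        # match indices).  Both inner tuples have seven positions, in the
        # same order as the class-level keylists: host keys, CA keys, revoked
        # keys, X.509 certificates, revoked X.509 certificates, X.509 subject
        # patterns and revoked X.509 subject patterns.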
matches = ( ('Empty file', ([], [], [], [], [], [], []), ([], [], [], [], [], [], [])), ('Exact host and port', (['[host]:22'], [], [], [], [], [], []), ([0], [], [], [], [], [], [])), ('Exact host', (['host'], [], [], [], [], [], []), ([0], [], [], [], [], [], [])), ('Exact host CA', ([], ['host'], [], [], [], [], []), ([], [0], [], [], [], [], [])), ('Exact host revoked', ([], [], ['host'], [], [], [], []), ([], [], [0], [], [], [], [])), ('Multiple exact', (['host'], ['host'], [], [], [], [], []), ([0], [0], [], [], [], [], [])), ('Wildcard host', (['hos*'], [], [], [], [], [], []), ([0], [], [], [], [], [], [])), ('Mismatched port', (['[host]:23'], [], [], [], [], [], []), ([], [], [], [], [], [], [])), ('Negative host', (['hos*,!host'], [], [], [], [], [], []), ([], [], [], [], [], [], [])), ('Exact addr and port', (['[1.2.3.4]:22'], [], [], [], [], [], []), ([0], [], [], [], [], [], [])), ('Exact addr', (['1.2.3.4'], [], [], [], [], [], []), ([0], [], [], [], [], [], [])), ('Subnet', (['1.2.3.0/24'], [], [], [], [], [], []), ([0], [], [], [], [], [], [])), ('Negative addr', (['1.2.3.0/24,!1.2.3.4'], [], [], [], [], [], []), ([], [], [], [], [], [], [])), ('Hashed host', ([_hash('host')], [], [], [], [], [], []), ([0], [], [], [], [], [], [])), ('Hashed addr', ([_hash('1.2.3.4')], [], [], [], [], [], []), ([0], [], [], [], [], [], [])) ) if x509_available: # pragma: no branch matches += ( ('Exact host X.509', ([], [], [], ['host'], [], [], []), ([], [], [], [0], [], [], [])), ('Exact host X.509 revoked', ([], [], [], [], ['host'], [], []), ([], [], [], [], [0], [], [])), ('Exact host subject', ([], [], [], [], [], ['host'], []), ([], [], [], [], [], [0], [])), ('Exact host revoked subject', ([], [], [], [], [], [], ['host']), ([], [], [], [], [], [], [0])), ) for testname, patlists, result in matches: with self.subTest(testname): self.check_hosts(patlists, result) def test_no_addr(self): """Test match without providing addr""" self.check_hosts((['host'], [], [], [], [], [], []), ([0], [], [], [], [], [], []), addr=None) self.check_hosts((['1.2.3.4'], [], [], [], [], [], []), ([], [], [], [], [], [], []), addr=None) def test_no_port(self): """Test match without providing port""" self.check_hosts((['host'], [], [], [], [], [], []), ([0], [], [], [], [], [], []), port=None) self.check_hosts((['[host]:22'], [], [], [], [], [], []), ([], [], [], [], [], [], []), port=None) def test_no_match(self): """Test for cases where no match is found""" no_match = (([], [], [], [], [], [], []), (['host1', 'host2'], [], [], [], [], [], []), (['2.3.4.5', '3.4.5.6'], [], [], [], [], [], []), (['[host]:2222', '[host]:22222'], [], [], [], [], [], [])) for patlists in no_match: self.check_hosts(patlists, ([], [], [], [], [], [], [])) def test_missing_key(self): """Test for line with missing key data""" with self.assertRaises(ValueError): self.check_match(b'xxx\n') def test_missing_key_with_tag(self): """Test for line with tag with missing key data""" with self.assertRaises(ValueError): self.check_match(b'@cert-authority xxx\n') def test_invalid_key(self): """Test for line with invalid key""" self.check_match(b'xxx yyy\n', ([], [], [], [], [], [], [])) def test_invalid_marker(self): """Test for line with invalid marker""" with self.assertRaises(ValueError): self.check_match(b'@xxx yyy zzz\n') def test_incomplete_hash(self): """Test for line with incomplete host hash""" with self.assertRaises(ValueError): self.check_hosts((['|1|aaaa'], [], [], [], [], [], [], [], [])) def test_invalid_hash(self): """Test for 
line with invalid host hash""" with self.assertRaises(ValueError): self.check_hosts((['|1|aaa'], [], [], [], [], [], [], [], [])) def test_unknown_hash_type(self): """Test for line with unknown host hash type""" with self.assertRaises(ValueError): self.check_hosts((['|2|aaaa|'], [], [], [], [], [], [], [], [])) def test_file(self): """Test match against file""" self.check_hosts((['host'], [], [], [], [], [], []), ([0], [], [], [], [], [], []), from_file=True) def test_bytes(self): """Test match against byte string""" self.check_hosts((['host'], [], [], [], [], [], []), ([0], [], [], [], [], [], []), from_bytes=True) def test_callable(self): """Test match using callable""" self.check_hosts((['host'], [], [], [], [], [], []), ([0], [], [], [], [], [], []), as_callable=True) def test_tuple(self): """Test passing already constructed tuple of keys""" self.check_hosts((['host'], [], [], [], [], [], []), ([0], [], [], [], [], [], []), as_tuple=True) asyncssh-1.11.1/tests/test_mac.py000066400000000000000000000035441320320510200167330ustar00rootroot00000000000000# Copyright (c) 2015 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for message authentication""" import os import unittest from asyncssh.mac import get_mac_algs, get_mac_params, get_mac class _TestMAC(unittest.TestCase): """Unit tests for mac module""" def test_mac_algs(self): """Unit test MAC algorithms""" for alg in get_mac_algs(): with self.subTest(alg=alg): keysize, _, _ = get_mac_params(alg) key = os.urandom(keysize) packet = os.urandom(256) enc_mac = get_mac(alg, key) dec_mac = get_mac(alg, key) badpacket = bytearray(packet) badpacket[-1] ^= 0xff mac = enc_mac.sign(0, packet) badmac = bytearray(mac) badmac[-1] ^= 0xff self.assertTrue(dec_mac.verify(0, packet, mac)) self.assertFalse(dec_mac.verify(0, bytes(badpacket), mac)) self.assertFalse(dec_mac.verify(0, packet, bytes(badmac))) def test_umac_wrapper(self): """Unit test some unused parts of the UMAC wrapper code""" try: from asyncssh.crypto import umac32 except ImportError: # pragma: no cover self.skipTest('umac not available') key = os.urandom(16) mac1 = umac32(key) mac1.update(b'test') mac2 = mac1.copy() mac1.update(b'123') mac2.update(b'123') self.assertEqual(mac1.hexdigest(), mac2.hexdigest()) asyncssh-1.11.1/tests/test_packet.py000066400000000000000000000160551320320510200174430ustar00rootroot00000000000000# Copyright (c) 2016 by Ron Frederick . # All rights reserved. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for SSH packet encoding and decoding""" import codecs import unittest from asyncssh.packet import Byte, Boolean, UInt32, UInt64, String, MPInt from asyncssh.packet import NameList, PacketDecodeError, SSHPacket from asyncssh.packet import SSHPacketHandler class _TestPacket(unittest.TestCase): """Unit tests for SSH packet module""" # pylint: disable=bad-whitespace tests = [ (Byte, SSHPacket.get_byte, [ (0, '00'), (127, '7f'), (128, '80'), (255, 'ff') ]), (Boolean, SSHPacket.get_boolean, [ (False, '00'), (True, '01') ]), (UInt32, SSHPacket.get_uint32, [ (0, '00000000'), (256, '00000100'), (0x12345678, '12345678'), (0x7fffffff, '7fffffff'), (0x80000000, '80000000'), (0xffffffff, 'ffffffff') ]), (UInt64, SSHPacket.get_uint64, [ (0, '0000000000000000'), (256, '0000000000000100'), (0x123456789abcdef0, '123456789abcdef0'), (0x7fffffffffffffff, '7fffffffffffffff'), (0x8000000000000000, '8000000000000000'), (0xffffffffffffffff, 'ffffffffffffffff') ]), (String, SSHPacket.get_string, [ (b'', '00000000'), (b'foo', '00000003666f6f'), (1024*b'\xff', '00000400' + 1024*'ff') ]), (MPInt, SSHPacket.get_mpint, [ (0, '00000000'), (1, '0000000101'), (127, '000000017f'), (128, '000000020080'), (32767, '000000027fff'), (32768, '00000003008000'), (0x123456789abcdef01234, '0000000a123456789abcdef01234'), (-1, '00000001ff'), (-128, '0000000180'), (-129, '00000002ff7f'), (-32768, '000000028000'), (-32769, '00000003ff7fff'), (-0xdeadbeef, '00000005ff21524111') ]), (NameList, SSHPacket.get_namelist, [ ([], '00000000'), ([b'foo'], '00000003666f6f'), ([b'foo', b'bar'], '00000007666f6f2c626172') ]) ] encode_errors = [ (Byte, -1, ValueError), (Byte, 256, ValueError), (Byte, 'a', TypeError), (UInt32, None, AttributeError), (UInt32, -1, OverflowError), (UInt32, 0x100000000, OverflowError), (UInt64, None, AttributeError), (UInt64, -1, OverflowError), (UInt64, 0x10000000000000000, OverflowError), (String, None, TypeError), (String, True, TypeError), (String, 0, TypeError), (MPInt, None, AttributeError), (MPInt, '', AttributeError), (MPInt, [], AttributeError), (NameList, None, TypeError), (NameList, 'xxx', TypeError) ] decode_errors = [ (SSHPacket.get_byte, ''), (SSHPacket.get_byte, '1234'), (SSHPacket.get_boolean, ''), (SSHPacket.get_boolean, '1234'), (SSHPacket.get_uint32, '123456'), (SSHPacket.get_uint32, '1234567890'), (SSHPacket.get_uint64, '12345678'), (SSHPacket.get_uint64, '123456789abcdef012'), (SSHPacket.get_string, '123456'), (SSHPacket.get_string, '12345678'), (SSHPacket.get_string, '000000011234') ] # pylint: enable=bad-whitespace def test_packet(self): """Unit test SSH packet module""" for encode, decode, values in self.tests: for value, data in values: data = codecs.decode(data, 'hex') with self.subTest(msg='encode', value=value): self.assertEqual(encode(value), data) with self.subTest(msg='decode', data=data): packet = SSHPacket(data) decoded_value = decode(packet) packet.check_end() self.assertEqual(decoded_value, value) self.assertEqual(packet.get_consumed_payload(), data) self.assertEqual(packet.get_remaining_payload(), b'') for encode, value, exc in self.encode_errors: with self.subTest(msg='encode error', encode=encode, value=value): with self.assertRaises(exc): encode(value) for decode, 
data in self.decode_errors: with self.subTest(msg='decode error', data=data): with self.assertRaises(PacketDecodeError): packet = SSHPacket(codecs.decode(data, 'hex')) decode(packet) packet.check_end() def test_unicode(self): """Unit test encoding of UTF-8 string""" self.assertEqual(String('\u2000'), b'\x00\x00\x00\x03\xe2\x80\x80') def test_handler(self): """Unit test SSH packet handler""" class _TestPacketHandler(SSHPacketHandler): """Class for unit testing SSHPacketHandler""" def _handler1(self, pkttype, packet): """Packet handler for unit testing""" packet_handlers = { 1: _handler1 } handler = _TestPacketHandler() packet = SSHPacket(b'') self.assertTrue(handler.process_packet(1, packet)) self.assertFalse(handler.process_packet(2, packet)) asyncssh-1.11.1/tests/test_process.py000066400000000000000000001035731320320510200176540ustar00rootroot00000000000000# Copyright (c) 2016 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for AsyncSSH process API""" import asyncio import io import os import socket import sys import unittest import asyncssh from .server import ServerTestCase from .util import asynctest, echo def _handle_client(process): """Handle a new client request""" # pylint: disable=no-self-use action = process.command or process.subsystem if not action: action = 'echo' if action == 'break': try: yield from process.stdin.readline() except asyncssh.BreakReceived as exc: process.exit_with_signal('ABRT', False, str(exc.msec)) elif action == 'delay': yield from asyncio.sleep(1) yield from echo(process.stdin, process.stdout, process.stderr) elif action == 'echo': yield from echo(process.stdin, process.stdout, process.stderr) elif action == 'exit_status': process.channel.set_encoding('utf-8') process.stderr.write('Exiting with status 1') process.exit(1) elif action == 'env': process.channel.set_encoding('utf-8') process.stdout.write(process.env.get('TEST', '')) elif action == 'redirect_stdin': yield from process.redirect_stdin(process.stdout) yield from process.stdout.drain() elif action == 'redirect_stdout': yield from process.redirect_stdout(process.stdin) yield from process.stdout.drain() elif action == 'redirect_stderr': yield from process.redirect_stderr(process.stdin) yield from process.stderr.drain() elif action == 'term': info = str((process.get_terminal_type(), process.get_terminal_size(), process.get_terminal_mode(asyncssh.PTY_OP_OSPEED))) process.channel.set_encoding('utf-8') process.stdout.write(info) elif action == 'term_size': try: yield from process.stdin.readline() except asyncssh.TerminalSizeChanged as exc: process.exit_with_signal('ABRT', False, '%sx%s' % (exc.width, exc.height)) else: process.exit(255) process.close() yield from process.wait_closed() class _TestProcess(ServerTestCase): """Unit tests for AsyncSSH process API""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server for the tests to use""" return (yield from cls.create_server(process_factory=_handle_client, session_encoding=None)) class _TestProcessBasic(_TestProcess): """Unit tests for AsyncSSH process basic functions""" @asynctest def test_shell(self): """Test starting a remote shell""" data = str(id(self)) with (yield from self.connect()) as conn: process = yield from 
conn.create_process(env={'TEST': 'test'}) process.stdin.write(data) process.stdin.write_eof() result = yield from process.wait() self.assertEqual(result.env, {'TEST': 'test'}) self.assertEqual(result.command, None) self.assertEqual(result.subsystem, None) self.assertEqual(result.exit_status, None) self.assertEqual(result.exit_signal, None) self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest def test_command(self): """Test executing a remote command""" data = str(id(self)) with (yield from self.connect()) as conn: process = yield from conn.create_process('echo') process.stdin.write(data) process.stdin.write_eof() result = yield from process.wait() self.assertEqual(result.command, 'echo') self.assertEqual(result.subsystem, None) self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest def test_subsystem(self): """Test starting a remote subsystem""" data = str(id(self)) with (yield from self.connect()) as conn: process = yield from conn.create_process(subsystem='echo') process.stdin.write(data) process.stdin.write_eof() result = yield from process.wait() self.assertEqual(result.command, None) self.assertEqual(result.subsystem, 'echo') self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest def test_communicate(self): """Test communicate""" data = str(id(self)) with (yield from self.connect()) as conn: with (yield from conn.create_process()) as process: stdout_data, stderr_data = yield from process.communicate(data) self.assertEqual(stdout_data, data) self.assertEqual(stderr_data, data) @asynctest def test_communicate_paused(self): """Test communicate when reading is already paused""" data = 4*1024*1024*'*' with (yield from self.connect()) as conn: with (yield from conn.create_process(input=data)) as process: yield from asyncio.sleep(1) stdout_data, stderr_data = yield from process.communicate() self.assertEqual(stdout_data, data) self.assertEqual(stderr_data, data) @asynctest def test_env(self): """Test sending environment""" with (yield from self.connect()) as conn: process = yield from conn.create_process('env', env={'TEST': 'test'}) result = yield from process.wait() self.assertEqual(result.stdout, 'test') @asynctest def test_terminal_info(self): """Test sending terminal information""" modes = {asyncssh.PTY_OP_OSPEED: 9600} with (yield from self.connect()) as conn: process = yield from conn.create_process('term', term_type='ansi', term_size=(80, 24), term_modes=modes) result = yield from process.wait() self.assertEqual(result.stdout, "('ansi', (80, 24, 0, 0), 9600)") @asynctest def test_change_terminal_size(self): """Test changing terminal size""" with (yield from self.connect()) as conn: process = yield from conn.create_process('term_size', term_type='ansi') process.change_terminal_size(80, 24) result = yield from process.wait() self.assertEqual(result.exit_signal[2], '80x24') @asynctest def test_break(self): """Test sending a break""" with (yield from self.connect()) as conn: process = yield from conn.create_process('break') process.send_break(1000) result = yield from process.wait() self.assertEqual(result.exit_signal[2], '1000') @asynctest def test_signal(self): """Test sending a signal""" with (yield from self.connect()) as conn: process = yield from conn.create_process() process.send_signal('HUP') result = yield from process.wait() self.assertEqual(result.exit_signal[0], 'HUP') @asynctest def test_terminate(self): """Test sending a terminate signal""" with (yield from self.connect()) as 
conn: process = yield from conn.create_process() process.terminate() result = yield from process.wait() self.assertEqual(result.exit_signal[0], 'TERM') @asynctest def test_kill(self): """Test sending a kill signal""" with (yield from self.connect()) as conn: process = yield from conn.create_process() process.kill() result = yield from process.wait() self.assertEqual(result.exit_signal[0], 'KILL') @asynctest def test_exit_status(self): """Test checking exit status""" with (yield from self.connect()) as conn: result = yield from conn.run('exit_status') self.assertEqual(result.exit_status, 1) self.assertEqual(result.stdout, '') self.assertEqual(result.stderr, 'Exiting with status 1') @asynctest def test_raise_on_exit_status(self): """Test raising an exception on non-zero exit status""" with (yield from self.connect()) as conn: with self.assertRaises(asyncssh.ProcessError) as exc: yield from conn.run('exit_status', env={'TEST': 'test'}, check=True) self.assertEqual(exc.exception.env, {'TEST': 'test'}) self.assertEqual(exc.exception.command, 'exit_status') self.assertEqual(exc.exception.subsystem, None) self.assertEqual(exc.exception.exit_status, 1) self.assertEqual(exc.exception.reason, 'Process exited with non-zero exit status 1') @asynctest def test_exit_signal(self): """Test checking exit signal""" with (yield from self.connect()) as conn: process = yield from conn.create_process() process.send_signal('HUP') result = yield from process.wait() self.assertEqual(result.exit_status, -1) self.assertEqual(result.exit_signal[0], 'HUP') @asynctest def test_raise_on_exit_signal(self): """Test raising an exception on exit signal""" with (yield from self.connect()) as conn: process = yield from conn.create_process() with self.assertRaises(asyncssh.ProcessError) as exc: process.send_signal('HUP') yield from process.wait(check=True) self.assertEqual(exc.exception.exit_status, -1) self.assertEqual(exc.exception.exit_signal[0], 'HUP') self.assertEqual(exc.exception.reason, 'Process exited with signal HUP') @asynctest def test_split_unicode(self): """Test Unicode split across blocks""" data = '\u2000test\u2000' with open('stdin', 'w', encoding='utf-8') as file: file.write(data) with (yield from self.connect()) as conn: result = yield from conn.run('echo', stdin='stdin', bufsize=2) self.assertEqual(result.stdout, data) @asynctest def test_invalid_unicode(self): """Test invalid Unicode data""" data = b'\xfftest' with open('stdin', 'wb') as file: file.write(data) with (yield from self.connect()) as conn: with self.assertRaises(asyncssh.DisconnectError): yield from conn.run('echo', stdin='stdin') @asynctest def test_incomplete_unicode(self): """Test incomplete Unicode data""" data = '\u2000'.encode('utf-8')[:2] with open('stdin', 'wb') as file: file.write(data) with (yield from self.connect()) as conn: with self.assertRaises(asyncssh.DisconnectError): yield from conn.run('echo', stdin='stdin') @asynctest def test_disconnect(self): """Test collecting output from a disconnected channel""" data = str(id(self)) with (yield from self.connect()) as conn: process = yield from conn.create_process() process.stdin.write(data) process.send_signal('ABRT') result = yield from process.wait() self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest def test_unknown_action(self): """Test unknown action""" with (yield from self.connect()) as conn: result = yield from conn.run('unknown') self.assertEqual(result.exit_status, 255) class _TestProcessRedirection(_TestProcess): """Unit tests for 
AsyncSSH process I/O redirection""" @asynctest def test_input(self): """Test with input from a string""" data = str(id(self)) with (yield from self.connect()) as conn: result = yield from conn.run('echo', input=data) self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest def test_stdin_devnull(self): """Test with stdin redirected to DEVNULL""" with (yield from self.connect()) as conn: result = yield from conn.run('echo', stdin=asyncssh.DEVNULL) self.assertEqual(result.stdout, '') self.assertEqual(result.stderr, '') @asynctest def test_stdin_file(self): """Test with stdin redirected to a file""" data = str(id(self)) with open('stdin', 'w') as file: file.write(data) with (yield from self.connect()) as conn: result = yield from conn.run('echo', stdin='stdin') self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest def test_stdin_binary_file(self): """Test with stdin redirected to a file in binary mode""" data = str(id(self)).encode() + b'\xff' with open('stdin', 'wb') as file: file.write(data) with (yield from self.connect()) as conn: result = yield from conn.run('echo', stdin='stdin', encoding=None) self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest def test_stdin_open_file(self): """Test with stdin redirected to an open file""" data = str(id(self)) with open('stdin', 'w') as file: file.write(data) file = open('stdin', 'r') with (yield from self.connect()) as conn: result = yield from conn.run('echo', stdin=file) self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest def test_stdin_open_binary_file(self): """Test with stdin redirected to an open file in binary mode""" data = str(id(self)).encode() + b'\xff' with open('stdin', 'wb') as file: file.write(data) file = open('stdin', 'rb') with (yield from self.connect()) as conn: result = yield from conn.run('echo', stdin=file, encoding=None) self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest def test_stdin_stringio(self): """Test with stdin redirected to a StringIO object""" data = str(id(self)) with open('stdin', 'w') as file: file.write(data) file = io.StringIO(data) with (yield from self.connect()) as conn: result = yield from conn.run('echo', stdin=file) self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest def test_stdin_bytesio(self): """Test with stdin redirected to a BytesIO object""" data = str(id(self)) with open('stdin', 'w') as file: file.write(data) file = io.BytesIO(data.encode('ascii')) with (yield from self.connect()) as conn: result = yield from conn.run('echo', stdin=file) self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest def test_stdin_process(self): """Test with stdin redirected to another SSH process""" data = str(id(self)) with (yield from self.connect()) as conn: proc1 = yield from conn.create_process(input=data) proc2 = yield from conn.create_process(stdin=proc1.stdout) result = yield from proc2.wait() self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest def test_stdout_devnull(self): """Test with stdout redirected to DEVNULL""" data = str(id(self)) with (yield from self.connect()) as conn: result = yield from conn.run('echo', input=data, stdout=asyncssh.DEVNULL) self.assertEqual(result.stdout, '') self.assertEqual(result.stderr, data) @asynctest def test_stdout_file(self): """Test with stdout redirected to a file""" data = str(id(self)) with (yield from 
self.connect()) as conn: result = yield from conn.run('echo', input=data, stdout='stdout') with open('stdout', 'r') as file: stdout_data = file.read() self.assertEqual(stdout_data, data) self.assertEqual(result.stdout, '') self.assertEqual(result.stderr, data) @asynctest def test_stdout_binary_file(self): """Test with stdout redirected to a file in binary mode""" data = str(id(self)).encode() + b'\xff' with (yield from self.connect()) as conn: result = yield from conn.run('echo', input=data, stdout='stdout', encoding=None) with open('stdout', 'rb') as file: stdout_data = file.read() self.assertEqual(stdout_data, data) self.assertEqual(result.stdout, b'') self.assertEqual(result.stderr, data) @asynctest def test_stdout_open_file(self): """Test with stdout redirected to an open file""" data = str(id(self)) file = open('stdout', 'w') with (yield from self.connect()) as conn: result = yield from conn.run('echo', input=data, stdout=file) with open('stdout', 'r') as file: stdout_data = file.read() self.assertEqual(stdout_data, data) self.assertEqual(result.stdout, '') self.assertEqual(result.stderr, data) @asynctest def test_stdout_open_binary_file(self): """Test with stdout redirected to an open binary file""" data = str(id(self)).encode() + b'\xff' file = open('stdout', 'wb') with (yield from self.connect()) as conn: result = yield from conn.run('echo', input=data, stdout=file, encoding=None) with open('stdout', 'rb') as file: stdout_data = file.read() self.assertEqual(stdout_data, data) self.assertEqual(result.stdout, b'') self.assertEqual(result.stderr, data) @asynctest def test_stdout_stringio(self): """Test with stdout redirected to a StringIO""" class _StringIOTest(io.StringIO): """Test class for StringIO which preserves output after close""" def __init__(self): super().__init__() self.output = None def close(self): if self.output is None: self.output = self.getvalue() super().close() data = str(id(self)) file = _StringIOTest() with (yield from self.connect()) as conn: result = yield from conn.run('echo', input=data, stdout=file) self.assertEqual(file.output, data) self.assertEqual(result.stdout, '') self.assertEqual(result.stderr, data) @asynctest def test_stdout_bytesio(self): """Test with stdout redirected to a BytesIO""" class _BytesIOTest(io.BytesIO): """Test class for BytesIO which preserves output after close""" def __init__(self): super().__init__() self.output = None def close(self): if self.output is None: self.output = self.getvalue() super().close() data = str(id(self)) file = _BytesIOTest() with (yield from self.connect()) as conn: result = yield from conn.run('echo', input=data, stdout=file) self.assertEqual(file.output, data.encode('ascii')) self.assertEqual(result.stdout, '') self.assertEqual(result.stderr, data) @asynctest def test_stdout_process(self): """Test with stdout redirected to another SSH process""" data = str(id(self)) with (yield from self.connect()) as conn: with (yield from conn.create_process()) as proc2: proc1 = yield from conn.create_process(stdout=proc2.stdin) proc1.stdin.write(data) proc1.stdin.write_eof() result = yield from proc2.wait() self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest def test_change_stdout(self): """Test changing stdout of an open process""" with (yield from self.connect()) as conn: process = yield from conn.create_process(stdout='stdout') process.stdin.write('xxx') yield from asyncio.sleep(0.1) yield from process.redirect_stdout(asyncssh.PIPE) process.stdin.write('yyy') process.stdin.write_eof() 
result = yield from process.wait() with open('stdout', 'r') as file: stdout_data = file.read() self.assertEqual(stdout_data, 'xxx') self.assertEqual(result.stdout, 'yyy') self.assertEqual(result.stderr, 'xxxyyy') @asynctest def test_change_stdin_process(self): """Test changing stdin of an open process reading from another""" data = str(id(self)) with (yield from self.connect()) as conn: with (yield from conn.create_process()) as proc2: proc1 = yield from conn.create_process(stdout=proc2.stdin) proc1.stdin.write(data) yield from asyncio.sleep(0.1) yield from proc2.redirect_stdin(asyncssh.PIPE) proc2.stdin.write(data) yield from asyncio.sleep(0.1) yield from proc2.redirect_stdin(proc1.stdout) proc1.stdin.write_eof() result = yield from proc2.wait() self.assertEqual(result.stdout, data+data) self.assertEqual(result.stderr, data+data) @asynctest def test_change_stdout_process(self): """Test changing stdout of an open process sending to another""" data = str(id(self)) with (yield from self.connect()) as conn: with (yield from conn.create_process()) as proc2: proc1 = yield from conn.create_process(stdout=proc2.stdin) proc1.stdin.write(data) yield from asyncio.sleep(0.1) yield from proc1.redirect_stdout(asyncssh.DEVNULL) proc1.stdin.write(data) yield from asyncio.sleep(0.1) yield from proc1.redirect_stdout(proc2.stdin) proc1.stdin.write_eof() result = yield from proc2.wait() self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest def test_stderr_stdout(self): """Test with stderr redirected to stdout""" data = str(id(self)) with (yield from self.connect()) as conn: result = yield from conn.run('echo', input=data, stderr=asyncssh.STDOUT) self.assertEqual(result.stdout, data+data) @asynctest def test_server_redirect_stdin(self): """Test redirect on server of stdin""" data = str(id(self)) with (yield from self.connect()) as conn: result = yield from conn.run('redirect_stdin', input=data) self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, '') @asynctest def test_server_redirect_stdout(self): """Test redirect on server of stdout""" data = str(id(self)) with (yield from self.connect()) as conn: result = yield from conn.run('redirect_stdout', input=data) self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, '') @asynctest def test_server_redirect_stderr(self): """Test redirect on server of stderr""" data = str(id(self)) with (yield from self.connect()) as conn: result = yield from conn.run('redirect_stderr', input=data) self.assertEqual(result.stdout, '') self.assertEqual(result.stderr, data) @asynctest def test_pause_file_reader(self): """Test pausing and resuming reading from a file""" data = 4*1024*1024*'*' with open('stdin', 'w') as file: file.write(data) with (yield from self.connect()) as conn: result = yield from conn.run('echo', stdin='stdin', stderr=asyncssh.DEVNULL) self.assertEqual(result.stdout, data) @asynctest def test_pause_process_reader(self): """Test pausing and resuming reading from another SSH process""" data = 4*1024*1024*'*' with (yield from self.connect()) as conn: proc1 = yield from conn.create_process(input=data) proc2 = yield from conn.create_process('delay', stdin=proc1.stdout, stderr=asyncssh.DEVNULL) proc3 = yield from conn.create_process('delay', stdin=proc1.stderr, stderr=asyncssh.DEVNULL) result2, result3 = yield from asyncio.gather(proc2.wait(), proc3.wait()) self.assertEqual(result2.stdout, data) self.assertEqual(result3.stdout, data) @asynctest def test_redirect_stdin_when_paused(self): """Test 
redirecting stdin when write is paused""" data = 4*1024*1024*'*' with open('stdin', 'w') as file: file.write(data) with (yield from self.connect()) as conn: process = yield from conn.create_process() process.stdin.write(data) yield from process.redirect_stdin('stdin') result = yield from process.wait() self.assertEqual(result.stdout, data+data) self.assertEqual(result.stderr, data+data) @asynctest def test_redirect_process_when_paused(self): """Test redirecting away from a process when write is paused""" data = 4*1024*1024*'*' with (yield from self.connect()) as conn: proc1 = yield from conn.create_process(input=data) proc2 = yield from conn.create_process('delay', stdin=proc1.stdout) proc3 = yield from conn.create_process('delay', stdin=proc1.stderr) yield from proc1.redirect_stderr(asyncssh.DEVNULL) result = yield from proc2.wait() proc3.close() self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest def test_consecutive_redirect(self): """Test consecutive redirects using drain""" data = 4*1024*1024*'*' with open('stdin', 'w') as file: file.write(data) with (yield from self.connect()) as conn: process = yield from conn.create_process() yield from process.redirect_stdin('stdin', send_eof=False) yield from process.stdin.drain() yield from process.redirect_stdin('stdin') result = yield from process.wait() self.assertEqual(result.stdout, data+data) self.assertEqual(result.stderr, data+data) @unittest.skipIf(sys.platform == 'win32', 'skip pipe tests on Windows') class _TestProcessPipes(_TestProcess): """Unit tests for AsyncSSH process I/O using pipes""" @asynctest def test_stdin_pipe(self): """Test with stdin redirected to a pipe""" data = str(id(self)) rpipe, wpipe = os.pipe() os.write(wpipe, data.encode()) os.close(wpipe) with (yield from self.connect()) as conn: result = yield from conn.run('echo', stdin=rpipe) self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest def test_stdin_text_pipe(self): """Test with stdin redirected to a pipe in text mode""" data = str(id(self)) rpipe, wpipe = os.pipe() rpipe = os.fdopen(rpipe, 'r') wpipe = os.fdopen(wpipe, 'w') wpipe.write(data) wpipe.close() with (yield from self.connect()) as conn: result = yield from conn.run('echo', stdin=rpipe) self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest def test_stdin_binary_pipe(self): """Test with stdin redirected to a pipe in binary mode""" data = str(id(self)).encode() + b'\xff' rpipe, wpipe = os.pipe() os.write(wpipe, data) os.close(wpipe) with (yield from self.connect()) as conn: result = yield from conn.run('echo', stdin=rpipe, encoding=None) self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest def test_stdout_pipe(self): """Test with stdout redirected to a pipe""" data = str(id(self)) rpipe, wpipe = os.pipe() with (yield from self.connect()) as conn: result = yield from conn.run('echo', input=data, stdout=wpipe) stdout_data = os.read(rpipe, 1024) os.close(rpipe) self.assertEqual(stdout_data.decode(), data) self.assertEqual(result.stdout, '') self.assertEqual(result.stderr, data) @asynctest def test_stdout_text_pipe(self): """Test with stdout redirected to a pipe in text mode""" data = str(id(self)) rpipe, wpipe = os.pipe() rpipe = os.fdopen(rpipe, 'r') wpipe = os.fdopen(wpipe, 'w') with (yield from self.connect()) as conn: result = yield from conn.run('echo', input=data, stdout=wpipe) stdout_data = rpipe.read(1024) rpipe.close() self.assertEqual(stdout_data, data) 
self.assertEqual(result.stdout, '') self.assertEqual(result.stderr, data) @asynctest def test_stdout_binary_pipe(self): """Test with stdout redirected to a pipe in binary mode""" data = str(id(self)).encode() + b'\xff' rpipe, wpipe = os.pipe() with (yield from self.connect()) as conn: result = yield from conn.run('echo', input=data, stdout=wpipe, encoding=None) stdout_data = os.read(rpipe, 1024) os.close(rpipe) self.assertEqual(stdout_data, data) self.assertEqual(result.stdout, b'') self.assertEqual(result.stderr, data) @unittest.skipIf(sys.platform == 'win32', 'skip socketpair tests on Windows') class _TestProcessSocketPair(_TestProcess): """Unit tests for AsyncSSH process I/O using socketpair""" @asynctest def test_stdin_socketpair(self): """Test with stdin redirected to a socketpair""" data = str(id(self)) sock1, sock2 = socket.socketpair() sock1.send(data.encode()) sock1.close() with (yield from self.connect()) as conn: result = yield from conn.run('echo', stdin=sock2) self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest def test_change_stdin(self): """Test changing stdin of an open process""" sock1, sock2 = socket.socketpair() sock3, sock4 = socket.socketpair() sock1.send(b'xxx') sock3.send(b'yyy') with (yield from self.connect()) as conn: process = yield from conn.create_process(stdin=sock2) yield from asyncio.sleep(0.1) yield from process.redirect_stdin(sock4) sock1.close() sock3.close() result = yield from process.wait() self.assertEqual(result.stdout, 'xxxyyy') self.assertEqual(result.stderr, 'xxxyyy') @asynctest def test_stdout_socketpair(self): """Test with stdout redirected to a socketpair""" data = str(id(self)) sock1, sock2 = socket.socketpair() with (yield from self.connect()) as conn: result = yield from conn.run('echo', input=data, stdout=sock1) stdout_data = sock2.recv(1024) sock2.close() self.assertEqual(stdout_data.decode(), data) self.assertEqual(result.stderr, data) @asynctest def test_pause_socketpair_reader(self): """Test pausing and resuming reading from a socketpair""" data = 4*1024*1024*'*' sock1, sock2 = socket.socketpair() _, writer = yield from asyncio.open_unix_connection(sock=sock1) writer.write(data.encode()) writer.close() with (yield from self.connect()) as conn: result = yield from conn.run('delay', stdin=sock2, stderr=asyncssh.DEVNULL) self.assertEqual(result.stdout, data) @asynctest def test_pause_socketpair_writer(self): """Test pausing and resuming writing to a socketpair""" data = 4*1024*1024*'*' rsock1, wsock1 = socket.socketpair() rsock2, wsock2 = socket.socketpair() reader1, writer1 = yield from asyncio.open_unix_connection(sock=rsock1) reader2, writer2 = yield from asyncio.open_unix_connection(sock=rsock2) with (yield from self.connect()) as conn: process = yield from conn.create_process(input=data) yield from asyncio.sleep(1) yield from process.redirect_stdout(wsock1) yield from process.redirect_stderr(wsock2) stdout_data, stderr_data = \ yield from asyncio.gather(reader1.read(), reader2.read()) writer1.close() writer2.close() yield from process.wait() self.assertEqual(stdout_data.decode(), data) self.assertEqual(stderr_data.decode(), data) asyncssh-1.11.1/tests/test_public_key.py000066400000000000000000002404051320320510200203200ustar00rootroot00000000000000# Copyright (c) 2014-2017 by Ron Frederick . # All rights reserved. 
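# --- Editor's illustrative sketch (not part of the original test file): the
# encrypted export/import round trip that check_private() below verifies,
# using only asyncssh calls that appear in these tests.  The file name
# 'testkey' is a hypothetical placeholder.

def _key_round_trip_example():
    """Export a private key encrypted with a passphrase and read it back"""

    import asyncssh

    key = asyncssh.read_private_key('testkey')

    # Export in OpenSSH format, encrypted under a passphrase
    encrypted = key.export_private_key('openssh', 'passphrase')

    # Importing with the same passphrase must yield an equal key
    copy = asyncssh.import_private_key(encrypted, 'passphrase')
    assert copy == key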
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for reading and writing public and private keys Note: These tests look for the openssl and ssh-keygen commands in the user's path and will whenever possible use them to perform interoperability tests. Otherwise, these tests will only test AsyncSSH against itself. """ import binascii from datetime import datetime import os import shutil import subprocess import sys import asyncssh from asyncssh.asn1 import der_encode, BitString, ObjectIdentifier from asyncssh.asn1 import TaggedDERObject from asyncssh.packet import MPInt, String, UInt32 from asyncssh.pbe import pkcs1_decrypt from asyncssh.public_key import CERT_TYPE_USER, CERT_TYPE_HOST, SSHKey from asyncssh.public_key import SSHX509CertificateChain from asyncssh.public_key import decode_ssh_certificate from asyncssh.public_key import get_public_key_algs, get_certificate_algs from asyncssh.public_key import get_x509_certificate_algs from asyncssh.public_key import import_certificate_subject from .util import bcrypt_available, libnacl_available, x509_available from .util import make_certificate, run, TempDirTestCase _ES1_SHA1_DES = ObjectIdentifier('1.2.840.113549.1.5.10') _P12_RC4_40 = ObjectIdentifier('1.2.840.113549.1.12.1.2') _ES2 = ObjectIdentifier('1.2.840.113549.1.5.13') _ES2_PBKDF2 = ObjectIdentifier('1.2.840.113549.1.5.12') _ES2_AES128 = ObjectIdentifier('2.16.840.1.101.3.4.1.2') _ES2_DES3 = ObjectIdentifier('1.2.840.113549.3.7') try: _openssl_version = run('openssl version') except subprocess.CalledProcessError: # pragma: no cover _openssl_version = b'' _openssl_available = _openssl_version != b'' # The openssl "-v2prf" option is only available in OpenSSL 1.0.2 or later _openssl_supports_v2prf = _openssl_version >= b'OpenSSL 1.0.2' try: if sys.platform != 'win32': _openssh_version = run('ssh -V') else: # pragma: no cover _openssh_version = b'' except subprocess.CalledProcessError: # pragma: no cover _openssh_version = b'' _openssh_available = _openssh_version != b'' # GCM & Chacha tests require OpenSSH 6.9 due to a bug in earlier versions: # https://bugzilla.mindrot.org/show_bug.cgi?id=2366 _openssh_supports_gcm_chacha = _openssh_version >= b'OpenSSH_6.9' _openssh_supports_arcfour_blowfish_cast = (_openssh_available and _openssh_version < b'OpenSSH_7.6') # pylint: disable=bad-whitespace pkcs1_ciphers = (('aes128-cbc', '-aes128'), ('aes192-cbc', '-aes192'), ('aes256-cbc', '-aes256'), ('des-cbc', '-des'), ('des3-cbc', '-des3')) pkcs8_ciphers = ( ('aes128-cbc', 'sha224', 2, '-v2 aes-128-cbc ' '-v2prf hmacWithSHA224', _openssl_supports_v2prf), ('aes128-cbc', 'sha256', 2, '-v2 aes-128-cbc ' '-v2prf hmacWithSHA256', _openssl_supports_v2prf), ('aes128-cbc', 'sha384', 2, '-v2 aes-128-cbc ' '-v2prf hmacWithSHA384', _openssl_supports_v2prf), ('aes128-cbc', 'sha512', 2, '-v2 aes-128-cbc ' '-v2prf hmacWithSHA512', _openssl_supports_v2prf), ('des-cbc', 'md5', 1, '-v1 PBE-MD5-DES', _openssl_available), ('des-cbc', 'sha1', 1, '-v1 PBE-SHA1-DES', _openssl_available), ('des2-cbc', 'sha1', 1, '-v1 PBE-SHA1-2DES', _openssl_available), ('des3-cbc', 'sha1', 1, '-v1 PBE-SHA1-3DES', _openssl_available), ('rc4-40', 'sha1', 1, '-v1 PBE-SHA1-RC4-40', _openssl_available), ('rc4-128', 'sha1', 1, '-v1 PBE-SHA1-RC4-128', _openssl_available), 
('aes128-cbc', 'sha1', 2, '-v2 aes-128-cbc', _openssl_available), ('aes192-cbc', 'sha1', 2, '-v2 aes-192-cbc', _openssl_available), ('aes256-cbc', 'sha1', 2, '-v2 aes-256-cbc', _openssl_available), ('blowfish-cbc', 'sha1', 2, '-v2 bf-cbc', _openssl_available), ('cast128-cbc', 'sha1', 2, '-v2 cast-cbc', _openssl_available), ('des-cbc', 'sha1', 2, '-v2 des-cbc', _openssl_available), ('des3-cbc', 'sha1', 2, '-v2 des-ede3-cbc', _openssl_available)) openssh_ciphers = ( ('aes128-gcm@openssh.com', _openssh_supports_gcm_chacha), ('aes256-gcm@openssh.com', _openssh_supports_gcm_chacha), ('arcfour', _openssh_supports_arcfour_blowfish_cast), ('arcfour128', _openssh_supports_arcfour_blowfish_cast), ('arcfour256', _openssh_supports_arcfour_blowfish_cast), ('blowfish-cbc', _openssh_supports_arcfour_blowfish_cast), ('cast128-cbc', _openssh_supports_arcfour_blowfish_cast), ('aes128-cbc', _openssh_available), ('aes192-cbc', _openssh_available), ('aes256-cbc', _openssh_available), ('aes128-ctr', _openssh_available), ('aes192-ctr', _openssh_available), ('aes256-ctr', _openssh_available), ('3des-cbc', _openssh_available) ) # pylint: enable=bad-whitespace # Only test Chacha if libnacl is installed if libnacl_available: # pragma: no branch openssh_ciphers += (('chacha20-poly1305@openssh.com', _openssh_supports_gcm_chacha),) def select_passphrase(cipher, pbe_version=0): """Randomize between string and bytes version of passphrase""" if cipher is None: return None elif os.urandom(1)[0] & 1: return 'passphrase' elif pbe_version == 1 and cipher in ('des2-cbc', 'des3-cbc', 'rc4-40', 'rc4-128'): return 'passphrase'.encode('utf-16-be') else: return 'passphrase'.encode('utf-8') class _TestPublicKey(TempDirTestCase): """Unit tests for public key modules""" # pylint: disable=too-many-public-methods keyclass = None base_format = None private_formats = () public_formats = () default_cert_version = '' x509_supported = False generate_args = () def __init__(self, methodName='runTest'): super().__init__(methodName) self.privkey = None self.pubkey = None self.privca = None self.pubca = None self.usercert = None self.hostcert = None self.rootx509 = None self.userx509 = None self.hostx509 = None self.otherx509 = None def make_certificate(self, *args, **kwargs): """Construct an SSH certificate""" return make_certificate(self.default_cert_version, *args, **kwargs) def validate_openssh(self, cert, cert_type, name): """Check OpenSSH certificate validation""" self.assertIsNone(cert.validate(cert_type, name)) def validate_x509(self, cert, user_principal=None): """Check X.509 certificate validation""" self.assertIsNone(cert.validate_chain([], [self.rootx509], [], 'any', user_principal, None)) with self.assertRaises(ValueError): cert.validate_chain([self.rootx509], [], [], 'any', None, None) chain = SSHX509CertificateChain.construct_from_certs([cert]) self.assertEqual(chain, decode_ssh_certificate(chain.data)) self.assertIsNone(chain.validate_chain([self.rootx509], [], [], 'any', user_principal, None)) self.assertIsNone(chain.validate_chain([self.rootx509], [], [self.otherx509], 'any', user_principal, None)) with self.assertRaises(ValueError): chain.validate_chain([], [], [], 'any', user_principal, None) with self.assertRaises(ValueError): chain.validate_chain([self.rootx509], [], [cert], 'any', user_principal, None) def check_private(self, format_name, passphrase=None): """Check for a private key match""" newkey = asyncssh.read_private_key('new', passphrase) algorithm = newkey.get_algorithm() keydata = newkey.export_private_key() pubdata 
= newkey.get_ssh_public_key() self.assertEqual(newkey, self.privkey) self.assertEqual(hash(newkey), hash(self.privkey)) keypair = asyncssh.load_keypairs(newkey, passphrase)[0] self.assertEqual(keypair.get_key_type(), 'local') self.assertEqual(keypair.get_algorithm(), algorithm) self.assertEqual(keypair.public_data, pubdata) self.assertIsNotNone(keypair.get_agent_private_key()) keypair = asyncssh.load_keypairs([keypair])[0] self.assertEqual(keypair.public_data, pubdata) keypair = asyncssh.load_keypairs(keydata)[0] self.assertEqual(keypair.public_data, pubdata) keypair = asyncssh.load_keypairs('new', passphrase)[0] self.assertEqual(keypair.public_data, pubdata) keypair = asyncssh.load_keypairs([newkey])[0] self.assertEqual(keypair.public_data, pubdata) keypair = asyncssh.load_keypairs([(newkey, None)])[0] self.assertEqual(keypair.public_data, pubdata) keypair = asyncssh.load_keypairs([keydata])[0] self.assertEqual(keypair.public_data, pubdata) keypair = asyncssh.load_keypairs([(keydata, None)])[0] self.assertEqual(keypair.public_data, pubdata) keypair = asyncssh.load_keypairs(['new'], passphrase)[0] self.assertEqual(keypair.public_data, pubdata) keypair = asyncssh.load_keypairs([('new', None)], passphrase)[0] self.assertEqual(keypair.public_data, pubdata) keylist = asyncssh.load_keypairs([]) self.assertEqual(keylist, []) if passphrase: with self.assertRaises((asyncssh.KeyEncryptionError, asyncssh.KeyImportError)): asyncssh.load_keypairs('new', 'xxx') else: newkey.write_private_key('list', format_name) newkey.append_private_key('list', format_name) keylist = asyncssh.load_keypairs('list') self.assertEqual(keylist[0].public_data, pubdata) self.assertEqual(keylist[1].public_data, pubdata) if self.x509_supported and format_name[-4:] == '-pem': cert = newkey.generate_x509_user_certificate(newkey, 'OU=user') chain = SSHX509CertificateChain.construct_from_certs([cert]) cert.write_certificate('new_cert') keypair = asyncssh.load_keypairs(('new', 'new_cert'), passphrase)[0] self.assertEqual(keypair.public_data, chain.data) self.assertIsNotNone(keypair.get_agent_private_key()) newkey.write_private_key('new_bundle', format_name, passphrase) cert.append_certificate('new_bundle', 'pem') keypair = asyncssh.load_keypairs('new_bundle', passphrase)[0] self.assertEqual(keypair.public_data, chain.data) with self.assertRaises(OSError): asyncssh.load_keypairs(('new', 'not_found'), passphrase) def check_public(self, format_name): """Check for a public key match""" newkey = asyncssh.read_public_key('new') pubdata = newkey.export_public_key() self.assertEqual(newkey, self.pubkey) self.assertEqual(hash(newkey), hash(self.pubkey)) keypair = asyncssh.load_public_keys('new')[0] self.assertEqual(keypair, newkey) keypair = asyncssh.load_public_keys([newkey])[0] self.assertEqual(keypair, newkey) keypair = asyncssh.load_public_keys([pubdata])[0] self.assertEqual(keypair, newkey) keypair = asyncssh.load_public_keys(['new'])[0] self.assertEqual(keypair, newkey) newkey.write_public_key('list', format_name) newkey.append_public_key('list', format_name) keylist = asyncssh.load_public_keys('list') self.assertEqual(keylist[0], newkey) self.assertEqual(keylist[1], newkey) def check_certificate(self, cert_type, format_name): """Check for a certificate match""" cert = asyncssh.read_certificate('cert') certdata = cert.export_certificate() self.assertEqual(cert.key, self.pubkey) if cert.is_x509: self.validate_x509(cert) else: self.validate_openssh(cert, cert_type, 'name') certlist = asyncssh.load_certificates(cert) 
self.assertEqual(certlist[0], cert) self.assertEqual(hash(certlist[0]), hash(cert)) if cert.is_x509: self.assertEqual(certlist[0].x509_cert, cert.x509_cert) self.assertEqual(hash(certlist[0].x509_cert), hash(cert.x509_cert)) certlist = asyncssh.load_certificates(certdata) self.assertEqual(certlist[0], cert) certlist = asyncssh.load_certificates([cert]) self.assertEqual(certlist[0], cert) certlist = asyncssh.load_certificates([certdata]) self.assertEqual(certlist[0], cert) cert.write_certificate('list', format_name) cert.append_certificate('list', format_name) certlist = asyncssh.load_certificates('list') self.assertEqual(certlist[0], cert) self.assertEqual(certlist[1], cert) certlist = asyncssh.load_certificates(['list', [cert]]) self.assertEqual(certlist[0], cert) self.assertEqual(certlist[1], cert) self.assertEqual(certlist[2], cert) certlist = asyncssh.load_certificates(['list', certdata]) self.assertEqual(certlist[0], cert) self.assertEqual(certlist[1], cert) self.assertEqual(certlist[2], cert) def import_pkcs1_private(self, fmt, cipher=None, args=None): """Check import of a PKCS#1 private key""" format_name = 'pkcs1-%s' % fmt if _openssl_available: # pragma: no branch if cipher: run('openssl %s %s -in priv -inform pem -out new -outform %s ' '-passout pass:passphrase' % (self.keyclass, args, fmt)) else: run('openssl %s -in priv -inform pem -out new -outform %s' % (self.keyclass, fmt)) else: # pragma: no cover self.privkey.write_private_key('new', format_name, select_passphrase(cipher), cipher) self.check_private(format_name, select_passphrase(cipher)) def export_pkcs1_private(self, fmt, cipher=None): """Check export of a PKCS#1 private key""" format_name = 'pkcs1-%s' % fmt self.privkey.write_private_key('privout', format_name, select_passphrase(cipher), cipher) if _openssl_available: # pragma: no branch if cipher: run('openssl %s -in privout -inform %s -out new -outform pem ' '-passin pass:passphrase' % (self.keyclass, fmt)) else: run('openssl %s -in privout -inform %s -out new -outform pem' % (self.keyclass, fmt)) else: # pragma: no cover priv = asyncssh.read_private_key('privout', select_passphrase(cipher)) priv.write_private_key('new', format_name) self.check_private(format_name) def import_pkcs1_public(self, fmt): """Check import of a PKCS#1 public key""" format_name = 'pkcs1-%s' % fmt if (not _openssl_available or self.keyclass == 'dsa' or _openssl_version < b'OpenSSL 1.0.0'): # pragma: no cover # OpenSSL no longer has support for PKCS#1 DSA, and PKCS#1 # RSA is not supported before OpenSSL 1.0.0, so we only test # against ourselves in these cases. self.pubkey.write_public_key('new', format_name) else: run('openssl %s -pubin -in pub -inform pem -RSAPublicKey_out ' '-out new -outform %s' % (self.keyclass, fmt)) self.check_public(format_name) def export_pkcs1_public(self, fmt): """Check export of a PKCS#1 public key""" format_name = 'pkcs1-%s' % fmt self.privkey.write_public_key('pubout', format_name) if not _openssl_available or self.keyclass == 'dsa': # pragma: no cover # OpenSSL no longer has support for PKCS#1 DSA, so we can # only test against ourselves. 
pub = asyncssh.read_public_key('pubout') pub.write_public_key('new', 'pkcs1-%s' % fmt) else: run('openssl %s -RSAPublicKey_in -in pubout -inform %s -out new ' '-outform pem' % (self.keyclass, fmt)) self.check_public(format_name) def import_pkcs8_private(self, fmt, use_openssl, cipher=None, hash_alg=None, pbe_version=None, args=None): """Check import of a PKCS#8 private key""" format_name = 'pkcs8-%s' % fmt if use_openssl: # pragma: no branch if cipher: run('openssl pkcs8 -topk8 %s -in priv -inform pem -out new ' '-outform %s -passout pass:passphrase' % (args, fmt)) else: run('openssl pkcs8 -topk8 -nocrypt -in priv -inform pem ' '-out new -outform %s' % fmt) else: # pragma: no cover self.privkey.write_private_key('new', format_name, select_passphrase(cipher, pbe_version), cipher, hash_alg, pbe_version) self.check_private(format_name, select_passphrase(cipher, pbe_version)) def export_pkcs8_private(self, fmt, use_openssl, cipher=None, hash_alg=None, pbe_version=None): """Check export of a PKCS#8 private key""" format_name = 'pkcs8-%s' % fmt self.privkey.write_private_key('privout', format_name, select_passphrase(cipher, pbe_version), cipher, hash_alg, pbe_version) if use_openssl: # pragma: no branch if cipher: run('openssl pkcs8 -in privout -inform %s -out new ' '-outform pem -passin pass:passphrase' % fmt) else: run('openssl pkcs8 -nocrypt -in privout -inform %s -out new ' '-outform pem' % fmt) else: # pragma: no cover priv = asyncssh.read_private_key('privout', select_passphrase(cipher, pbe_version)) priv.write_private_key('new', format_name) self.check_private(format_name) def import_pkcs8_public(self, fmt): """Check import of a PKCS#8 public key""" format_name = 'pkcs8-%s' % fmt if _openssl_available: # pragma: no branch run('openssl %s -pubin -in pub -inform pem -out new -outform %s' % (self.keyclass, fmt)) else: # pragma: no cover self.pubkey.write_public_key('new', format_name) self.check_public(format_name) def export_pkcs8_public(self, fmt): """Check export of a PKCS#8 public key""" format_name = 'pkcs8-%s' % fmt self.privkey.write_public_key('pubout', format_name) if _openssl_available: # pragma: no branch run('openssl %s -pubin -in pubout -inform %s -out new ' '-outform pem' % (self.keyclass, fmt)) else: # pragma: no cover pub = asyncssh.read_public_key('pubout') pub.write_public_key('new', format_name) self.check_public(format_name) def import_openssh_private(self, use_openssh, cipher=None): """Check import of an OpenSSH private key""" if use_openssh: # pragma: no branch shutil.copy('priv', 'new') if cipher: run('ssh-keygen -p -a 128 -N passphrase -Z %s -o -f new' % cipher) else: run('ssh-keygen -p -N "" -o -f new') else: # pragma: no cover self.privkey.write_private_key('new', 'openssh', select_passphrase(cipher), cipher) self.check_private('openssh', select_passphrase(cipher)) def export_openssh_private(self, use_openssh, cipher=None): """Check export of an OpenSSH private key""" self.privkey.write_private_key('new', 'openssh', select_passphrase(cipher), cipher) if use_openssh: # pragma: no branch os.chmod('new', 0o600) if cipher: run('ssh-keygen -p -P passphrase -N "" -o -f new') else: run('ssh-keygen -p -N "" -o -f new') else: # pragma: no cover priv = asyncssh.read_private_key('new', select_passphrase(cipher)) priv.write_private_key('new', 'openssh') self.check_private('openssh') def import_openssh_public(self): """Check import of an OpenSSH public key""" shutil.copy('sshpub', 'new') self.check_public('openssh') def export_openssh_public(self): """Check export of an OpenSSH 
public key""" self.privkey.write_public_key('pubout', 'openssh') if _openssh_available: # pragma: no branch run('ssh-keygen -e -f pubout -m rfc4716 > new') else: # pragma: no cover pub = asyncssh.read_public_key('pubout') pub.write_public_key('new', 'rfc4716') self.check_public('openssh') def import_openssh_certificate(self, cert_type, cert): """Check import of an OpenSSH certificate""" shutil.copy(cert, 'cert') self.check_certificate(cert_type, 'openssh') def export_openssh_certificate(self, cert_type, cert): """Check export of an OpenSSH certificate""" cert.write_certificate('certout', 'openssh') if _openssh_available: # pragma: no branch run('ssh-keygen -e -f certout -m rfc4716 > cert') else: # pragma: no cover cert = asyncssh.read_certificate('certout') cert.write_certificate('cert', 'rfc4716') self.check_certificate(cert_type, 'openssh') def import_rfc4716_public(self): """Check import of an RFC4716 public key""" if _openssh_available: # pragma: no branch run('ssh-keygen -e -f sshpub -m rfc4716 > new') else: # pragma: no cover self.pubkey.write_public_key('new', 'rfc4716') self.check_public('rfc4716') def export_rfc4716_public(self): """Check export of an RFC4716 public key""" self.pubkey.write_public_key('pubout', 'rfc4716') if _openssh_available: # pragma: no branch run('ssh-keygen -i -f pubout -m rfc4716 > new') else: # pragma: no cover pub = asyncssh.read_public_key('pubout') pub.write_public_key('new', 'openssh') self.check_public('rfc4716') def import_rfc4716_certificate(self, cert_type, cert): """Check import of an RFC4716 certificate""" if _openssh_available: # pragma: no branch run('ssh-keygen -e -f %s -m rfc4716 > cert' % cert) else: # pragma: no cover if cert_type == CERT_TYPE_USER: cert = self.usercert else: cert = self.hostcert cert.write_certificate('cert', 'rfc4716') self.check_certificate(cert_type, 'rfc4716') def export_rfc4716_certificate(self, cert_type, cert): """Check export of an RFC4716 certificate""" cert.write_certificate('certout', 'rfc4716') if _openssh_available: # pragma: no branch run('ssh-keygen -i -f certout -m rfc4716 > cert') else: # pragma: no cover cert = asyncssh.read_certificate('certout') cert.write_certificate('cert', 'openssh') self.check_certificate(cert_type, 'rfc4716') def import_der_x509_certificate(self, cert_type, cert): """Check import of a DER X.509 certificate""" if cert_type == CERT_TYPE_USER: cert = self.userx509 else: cert = self.hostx509 cert.write_certificate('cert', 'der') self.check_certificate(cert_type, 'der') def export_der_x509_certificate(self, cert_type, cert): """Check export of a DER X.509 certificate""" cert.write_certificate('certout', 'der') cert = asyncssh.read_certificate('certout') cert.write_certificate('cert', 'openssh') self.check_certificate(cert_type, 'der') def import_pem_x509_certificate(self, cert_type, cert): """Check import of a PEM X.509 certificate""" if cert_type == CERT_TYPE_USER: cert = self.userx509 else: cert = self.hostx509 cert.write_certificate('cert', 'pem') self.check_certificate(cert_type, 'pem') def export_pem_x509_certificate(self, cert_type, cert): """Check export of a PEM X.509 certificate""" cert.write_certificate('certout', 'pem') cert = asyncssh.read_certificate('certout') cert.write_certificate('cert', 'openssh') self.check_certificate(cert_type, 'pem') def import_openssh_x509_certificate(self, cert_type, cert): """Check import of an OpenSSH X.509 certificate""" if cert_type == CERT_TYPE_USER: cert = self.userx509 else: cert = self.hostx509 cert.write_certificate('cert') 
self.check_certificate(cert_type, 'openssh') def export_openssh_x509_certificate(self, cert_type, cert): """Check export of an OpenSSH X.509 certificate""" cert.write_certificate('certout') cert = asyncssh.read_certificate('certout') cert.write_certificate('cert', 'pem') self.check_certificate(cert_type, 'openssh') def check_encode_errors(self): """Check error code paths in key encoding""" for fmt in ('pkcs1-der', 'pkcs1-pem', 'pkcs8-der', 'pkcs8-pem', 'openssh', 'rfc4716', 'xxx'): with self.subTest('Encode private from public (%s)' % fmt): with self.assertRaises(asyncssh.KeyExportError): self.pubkey.export_private_key(fmt) with self.subTest('Encode with unknown key format'): with self.assertRaises(asyncssh.KeyExportError): self.privkey.export_public_key('xxx') with self.subTest('Encode encrypted pkcs1-der'): with self.assertRaises(asyncssh.KeyExportError): self.privkey.export_private_key('pkcs1-der', 'x') if self.keyclass == 'ec': with self.subTest('Encode EC public key with PKCS#1'): with self.assertRaises(asyncssh.KeyExportError): self.privkey.export_public_key('pkcs1-pem') if 'pkcs1' in self.private_formats: with self.subTest('Encode with unknown PKCS#1 cipher'): with self.assertRaises(asyncssh.KeyEncryptionError): self.privkey.export_private_key('pkcs1-pem', 'x', 'xxx') if 'pkcs8' in self.private_formats: with self.subTest('Encode with unknown PKCS#8 cipher'): with self.assertRaises(asyncssh.KeyEncryptionError): self.privkey.export_private_key('pkcs8-pem', 'x', 'xxx') with self.subTest('Encode with unknown PKCS#8 hash'): with self.assertRaises(asyncssh.KeyEncryptionError): self.privkey.export_private_key('pkcs8-pem', 'x', 'aes128-cbc', 'xxx') with self.subTest('Encode with unknown PKCS#8 version'): with self.assertRaises(asyncssh.KeyEncryptionError): self.privkey.export_private_key('pkcs8-pem', 'x', 'aes128-cbc', 'sha1', 3) if bcrypt_available: # pragma: no branch with self.subTest('Encode with unknown openssh cipher'): with self.assertRaises(asyncssh.KeyEncryptionError): self.privkey.export_private_key('openssh', 'x', 'xxx') with self.subTest('Encode agent cert private from public'): with self.assertRaises(asyncssh.KeyExportError): self.pubkey.encode_agent_cert_private() def check_decode_errors(self): """Check error code paths in key decoding""" private_errors = [ ('Non-ASCII', '\xff'), ('Incomplete ASN.1', b''), ('Invalid PKCS#1', der_encode(None)), ('Invalid PKCS#1 params', der_encode((1, b'', TaggedDERObject(0, b'')))), ('Invalid PKCS#1 EC named curve OID', der_encode((1, b'', TaggedDERObject(0, ObjectIdentifier('1.1'))))), ('Invalid PKCS#8', der_encode((0, (self.privkey.pkcs8_oid, ()), der_encode(None)))), ('Invalid PKCS#8 ASN.1', der_encode((0, (self.privkey.pkcs8_oid, None), b''))), ('Invalid PKCS#8 params', der_encode((1, (self.privkey.pkcs8_oid, b''), der_encode((1, b''))))), ('Invalid PEM header', b'-----BEGIN XXX-----\n'), ('Missing PEM footer', b'-----BEGIN PRIVATE KEY-----\n'), ('Invalid PEM key type', b'-----BEGIN XXX PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode(None)) + b'-----END XXX PRIVATE KEY-----'), ('Invalid PEM Base64', b'-----BEGIN PRIVATE KEY-----\n' b'X\n' b'-----END PRIVATE KEY-----'), ('Missing PKCS#1 passphrase', b'-----BEGIN DSA PRIVATE KEY-----\n' b'Proc-Type: 4,ENCRYPTED\n' b'-----END DSA PRIVATE KEY-----'), ('Incomplete PEM ASN.1', b'-----BEGIN PRIVATE KEY-----\n' b'-----END PRIVATE KEY-----'), ('Missing PEM PKCS#8 passphrase', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode(None)) + b'-----END ENCRYPTED PRIVATE 
KEY-----'), ('Invalid PEM PKCS#1 key', b'-----BEGIN DSA PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode(None)) + b'-----END DSA PRIVATE KEY-----'), ('Invalid PEM PKCS#8 key', b'-----BEGIN PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode(None)) + b'-----END PRIVATE KEY-----'), ('Unknown format OpenSSH key', b'-----BEGIN OPENSSH PRIVATE KEY-----\n' + binascii.b2a_base64(b'XXX') + b'-----END OPENSSH PRIVATE KEY-----'), ('Incomplete OpenSSH key', b'-----BEGIN OPENSSH PRIVATE KEY-----\n' + binascii.b2a_base64(b'openssh-key-v1\0') + b'-----END OPENSSH PRIVATE KEY-----'), ('Invalid OpenSSH nkeys', b'-----BEGIN OPENSSH PRIVATE KEY-----\n' + binascii.b2a_base64(b''.join( (b'openssh-key-v1\0', String(''), String(''), String(''), UInt32(2), String(''), String('')))) + b'-----END OPENSSH PRIVATE KEY-----'), ('Missing OpenSSH passphrase', b'-----BEGIN OPENSSH PRIVATE KEY-----\n' + binascii.b2a_base64(b''.join( (b'openssh-key-v1\0', String('xxx'), String(''), String(''), UInt32(1), String(''), String('')))) + b'-----END OPENSSH PRIVATE KEY-----'), ('Mismatched OpenSSH check bytes', b'-----BEGIN OPENSSH PRIVATE KEY-----\n' + binascii.b2a_base64(b''.join( (b'openssh-key-v1\0', String('none'), String(''), String(''), UInt32(1), String(''), String(b''.join((UInt32(1), UInt32(2))))))) + b'-----END OPENSSH PRIVATE KEY-----'), ('Invalid OpenSSH algorithm', b'-----BEGIN OPENSSH PRIVATE KEY-----\n' + binascii.b2a_base64(b''.join( (b'openssh-key-v1\0', String('none'), String(''), String(''), UInt32(1), String(''), String(b''.join((UInt32(1), UInt32(1), String('xxx'))))))) + b'-----END OPENSSH PRIVATE KEY-----'), ('Invalid OpenSSH pad', b'-----BEGIN OPENSSH PRIVATE KEY-----\n' + binascii.b2a_base64(b''.join( (b'openssh-key-v1\0', String('none'), String(''), String(''), UInt32(1), String(''), String(b''.join((UInt32(1), UInt32(1), String('ssh-dss'), 5*MPInt(0), String(''), b'\0')))))) + b'-----END OPENSSH PRIVATE KEY-----') ] decrypt_errors = [ ('Invalid PKCS#1', der_encode(None)), ('Invalid PKCS#8', der_encode((0, (self.privkey.pkcs8_oid, ()), der_encode(None)))), ('Invalid PEM params', b'-----BEGIN DSA PRIVATE KEY-----\n' b'Proc-Type: 4,ENCRYPTED\n' b'DEK-Info: XXX\n' b'-----END DSA PRIVATE KEY-----'), ('Invalid PEM cipher', b'-----BEGIN DSA PRIVATE KEY-----\n' b'Proc-Type: 4,ENCRYPTED\n' b'DEK-Info: XXX,00\n' b'-----END DSA PRIVATE KEY-----'), ('Invalid PEM IV', b'-----BEGIN DSA PRIVATE KEY-----\n' b'Proc-Type: 4,ENCRYPTED\n' b'DEK-Info: AES-256-CBC,XXX\n' b'-----END DSA PRIVATE KEY-----'), ('Invalid PEM PKCS#8 encrypted data', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode(None)) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Invalid PEM PKCS#8 encrypted header', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode((None, None))) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Invalid PEM PKCS#8 encryption algorithm', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode(((None, None), b''))) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Invalid PEM PKCS#8 PBES1 encryption parameters', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode(((_ES1_SHA1_DES, None), b''))) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Invalid PEM PKCS#8 PBES1 PKCS#12 encryption parameters', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode(((_P12_RC4_40, None), b''))) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Invalid PEM PKCS#8 PBES1 PKCS#12 salt', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + 
binascii.b2a_base64(der_encode(((_P12_RC4_40, (b'', 0)), b''))) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Invalid PEM PKCS#8 PBES1 PKCS#12 iteration count', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode(((_P12_RC4_40, (b'x', 0)), b''))) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Invalid PEM PKCS#8 PBES2 encryption parameters', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode(((_ES2, None), b''))) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Invalid PEM PKCS#8 PBES2 KDF algorithm', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode( ((_ES2, ((None, None), (None, None))), b''))) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Invalid PEM PKCS#8 PBES2 encryption algorithm', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode( ((_ES2, ((_ES2_PBKDF2, None), (None, None))), b''))) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Invalid PEM PKCS#8 PBES2 PBKDF2 parameters', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode( ((_ES2, ((_ES2_PBKDF2, None), (_ES2_AES128, None))), b''))) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Invalid PEM PKCS#8 PBES2 PBKDF2 salt', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode( ((_ES2, ((_ES2_PBKDF2, (None, None)), (_ES2_AES128, None))), b''))) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Invalid PEM PKCS#8 PBES2 PBKDF2 iteration count', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode( ((_ES2, ((_ES2_PBKDF2, (b'', None)), (_ES2_AES128, None))), b''))) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Invalid PEM PKCS#8 PBES2 PBKDF2 PRF', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode( ((_ES2, ((_ES2_PBKDF2, (b'', 0, None)), (_ES2_AES128, None))), b''))) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Unknown PEM PKCS#8 PBES2 PBKDF2 PRF', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode( ((_ES2, ((_ES2_PBKDF2, (b'', 0, (ObjectIdentifier('1.1'), None))), (_ES2_AES128, None))), b''))) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Invalid PEM PKCS#8 PBES2 encryption parameters', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode( ((_ES2, ((_ES2_PBKDF2, (b'', 0)), (_ES2_AES128, None))), b''))) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Invalid length PEM PKCS#8 PBES2 IV', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode( ((_ES2, ((_ES2_PBKDF2, (b'', 0)), (_ES2_AES128, b''))), b''))) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Invalid OpenSSH cipher', b'-----BEGIN OPENSSH PRIVATE KEY-----\n' + binascii.b2a_base64(b''.join( (b'openssh-key-v1\0', String('xxx'), String(''), String(''), UInt32(1), String(''), String('')))) + b'-----END OPENSSH PRIVATE KEY-----'), ('Invalid OpenSSH kdf', b'-----BEGIN OPENSSH PRIVATE KEY-----\n' + binascii.b2a_base64(b''.join( (b'openssh-key-v1\0', String('aes256-cbc'), String('xxx'), String(''), UInt32(1), String(''), String('')))) + b'-----END OPENSSH PRIVATE KEY-----'), ('Invalid OpenSSH kdf data', b'-----BEGIN OPENSSH PRIVATE KEY-----\n' + binascii.b2a_base64(b''.join( (b'openssh-key-v1\0', String('aes256-cbc'), String('bcrypt'), String(''), UInt32(1), String(''), String('')))) + b'-----END OPENSSH PRIVATE KEY-----'), ('Invalid OpenSSH salt', b'-----BEGIN OPENSSH PRIVATE KEY-----\n' + binascii.b2a_base64(b''.join( (b'openssh-key-v1\0', String('aes256-cbc'), String('bcrypt'), String(b''.join((String(b''), UInt32(128)))), UInt32(1), String(''), 
String('')))) + b'-----END OPENSSH PRIVATE KEY-----'), ('Invalid OpenSSH encrypted data', b'-----BEGIN OPENSSH PRIVATE KEY-----\n' + binascii.b2a_base64(b''.join( (b'openssh-key-v1\0', String('aes256-cbc'), String('bcrypt'), String(b''.join((String(16*b'\0'), UInt32(128)))), UInt32(1), String(''), String('')))) + b'-----END OPENSSH PRIVATE KEY-----'), ('Unexpected OpenSSH trailing data', b'-----BEGIN OPENSSH PRIVATE KEY-----\n' + binascii.b2a_base64(b''.join( (b'openssh-key-v1\0', String('aes256-cbc'), String('bcrypt'), String(b''.join((String(16*b'\0'), UInt32(128)))), UInt32(1), String(''), String(''), String('xxx')))) + b'-----END OPENSSH PRIVATE KEY-----') ] public_errors = [ ('Non-ASCII', '\xff'), ('Incomplete ASN.1', b''), ('Invalid ASN.1', b'\x30'), ('Invalid PKCS#1', der_encode(None)), ('Invalid PKCS#8', der_encode(((self.pubkey.pkcs8_oid, ()), BitString(der_encode(None))))), ('Invalid PKCS#8 ASN.1', der_encode(((self.pubkey.pkcs8_oid, None), BitString(b'')))), ('Invalid PEM header', b'-----BEGIN XXX-----\n'), ('Missing PEM footer', b'-----BEGIN PUBLIC KEY-----\n'), ('Invalid PEM key type', b'-----BEGIN XXX PUBLIC KEY-----\n' + binascii.b2a_base64(der_encode(None)) + b'-----END XXX PUBLIC KEY-----'), ('Invalid PEM Base64', b'-----BEGIN PUBLIC KEY-----\n' b'X\n' b'-----END PUBLIC KEY-----'), ('Incomplete PEM ASN.1', b'-----BEGIN PUBLIC KEY-----\n' b'-----END PUBLIC KEY-----'), ('Invalid PKCS#1 key data', b'-----BEGIN DSA PUBLIC KEY-----\n' + binascii.b2a_base64(der_encode(None)) + b'-----END DSA PUBLIC KEY-----'), ('Invalid PKCS#8 key data', b'-----BEGIN PUBLIC KEY-----\n' + binascii.b2a_base64(der_encode(None)) + b'-----END PUBLIC KEY-----'), ('Invalid OpenSSH', b'xxx'), ('Invalid OpenSSH Base64', b'ssh-dss X'), ('Unknown OpenSSH algorithm', b'ssh-dss ' + binascii.b2a_base64(String('xxx'))), ('Invalid OpenSSH body', b'ssh-dss ' + binascii.b2a_base64(String('ssh-dss'))), ('Invalid RFC4716 header', b'---- XXX ----\n'), ('Missing RFC4716 footer', b'---- BEGIN SSH2 PUBLIC KEY ----\n'), ('Invalid RFC4716 header', b'---- BEGIN SSH2 PUBLIC KEY ----\n' b'Comment: comment\n' b'XXX:\\\n' b'---- END SSH2 PUBLIC KEY ----\n'), ('Invalid RFC4716 Base64', b'---- BEGIN SSH2 PUBLIC KEY ----\n' b'X\n' b'---- END SSH2 PUBLIC KEY ----\n') ] keypair_errors = [ ('Mismatched certificate', (self.privca, self.usercert)), ('Invalid signature algorithm string', (self.privkey, None, 'xxx')), ('Invalid signature algorithm bytes', (self.privkey, None, b'xxx')) ] for fmt, data in private_errors: with self.subTest('Decode private (%s)' % fmt): with self.assertRaises(asyncssh.KeyImportError): asyncssh.import_private_key(data) for fmt, data in decrypt_errors: with self.subTest('Decrypt private (%s)' % fmt): with self.assertRaises((asyncssh.KeyEncryptionError, asyncssh.KeyImportError)): asyncssh.import_private_key(data, 'x') for fmt, data in public_errors: with self.subTest('Decode public (%s)' % fmt): with self.assertRaises(asyncssh.KeyImportError): asyncssh.import_public_key(data) for fmt, key in keypair_errors: with self.subTest('Load keypair (%s)' % fmt): with self.assertRaises(ValueError): asyncssh.load_keypairs([key]) def check_sshkey_base_errors(self): """Check SSHKey base class errors""" key = SSHKey(None) with self.subTest('SSHKey base class errors'): with self.assertRaises(asyncssh.KeyExportError): key.encode_pkcs1_private() with self.assertRaises(asyncssh.KeyExportError): key.encode_pkcs1_public() with self.assertRaises(asyncssh.KeyExportError): key.encode_pkcs8_private() with 
self.assertRaises(asyncssh.KeyExportError): key.encode_pkcs8_public() with self.assertRaises(asyncssh.KeyExportError): key.encode_ssh_private() with self.assertRaises(asyncssh.KeyExportError): key.encode_ssh_public() def check_sign_and_verify(self): """Check key signing and verification""" with self.subTest('Sign/verify test'): data = os.urandom(8) for cert in (None, self.usercert, self.userx509): keypair = asyncssh.load_keypairs([(self.privkey, cert)])[0] for sig_alg in keypair.sig_algorithms: with self.subTest('Good signature', sig_alg=sig_alg): keypair.set_sig_algorithm(sig_alg) sig = keypair.sign(data) with self.subTest('Good signature'): self.assertTrue(self.pubkey.verify(data, sig)) badsig = bytearray(sig) badsig[-1] ^= 0xff badsig = bytes(badsig) with self.subTest('Bad signature'): self.assertFalse(self.pubkey.verify(data, badsig)) with self.subTest('Missing signature'): self.assertFalse(self.pubkey.verify( data, String(self.pubkey.algorithm))) with self.subTest('Empty signature'): self.assertFalse(self.pubkey.verify( data, String(self.pubkey.algorithm) + String(b''))) with self.subTest('Sign with bad algorithm'): with self.assertRaises(ValueError): self.privkey.sign(data, 'xxx') with self.subTest('Verify with bad algorithm'): self.assertFalse(self.pubkey.verify( data, String('xxx') + String(''))) with self.subTest('Sign with public key'): with self.assertRaises(ValueError): self.pubkey.sign(data, self.pubkey.algorithm) def check_comment(self): """Check getting and setting comments""" with self.subTest('Comment test'): self.assertEqual(self.privkey.get_comment(), 'comment') self.assertEqual(self.pubkey.get_comment(), 'comment') key = asyncssh.import_private_key( self.privkey.export_private_key('openssh')) self.assertEqual(key.get_comment(), 'comment') key.set_comment('new_comment') self.assertEqual(key.get_comment(), 'new_comment') key.set_comment(b'new_comment') self.assertEqual(key.get_comment(), 'new_comment') for fmt in ('openssh', 'rfc4716'): key = asyncssh.import_public_key( self.pubkey.export_public_key(fmt)) self.assertEqual(key.get_comment(), 'comment') key = asyncssh.import_public_key( self.pubca.export_public_key(fmt)) self.assertEqual(key.get_comment(), None) key.set_comment('new_comment') self.assertEqual(key.get_comment(), 'new_comment') key.set_comment(b'new_comment') self.assertEqual(key.get_comment(), 'new_comment') for fmt in ('openssh', 'rfc4716'): cert = asyncssh.import_certificate( self.usercert.export_certificate(fmt)) self.assertEqual(cert.get_comment(), 'comment') cert = self.privca.generate_user_certificate( self.pubkey, 'name', comment='cert_comment') self.assertEqual(cert.get_comment(), 'cert_comment') cert = asyncssh.import_certificate( self.hostcert.export_certificate(fmt)) self.assertEqual(cert.get_comment(), 'comment') cert = self.privca.generate_host_certificate( self.pubkey, 'name', comment='cert_comment') self.assertEqual(cert.get_comment(), 'cert_comment') cert.set_comment('new_comment') self.assertEqual(cert.get_comment(), 'new_comment') cert.set_comment(b'new_comment') self.assertEqual(cert.get_comment(), 'new_comment') if self.x509_supported: for fmt in ('openssh', 'der', 'pem'): cert = asyncssh.import_certificate( self.rootx509.export_certificate(fmt)) self.assertEqual(cert.get_comment(), None) cert = self.privca.generate_x509_ca_certificate( self.pubkey, 'OU=root', comment='cert_comment') self.assertEqual(cert.get_comment(), 'cert_comment') cert = asyncssh.import_certificate( self.userx509.export_certificate(fmt)) 
self.assertEqual(cert.get_comment(), 'comment') cert = self.privca.generate_x509_user_certificate( self.pubkey, 'OU=user', 'OU=root', comment='cert_comment') self.assertEqual(cert.get_comment(), 'cert_comment') cert = asyncssh.import_certificate( self.hostx509.export_certificate(fmt)) self.assertEqual(cert.get_comment(), 'comment') cert = self.privca.generate_x509_host_certificate( self.pubkey, 'OU=host', 'OU=root', comment='cert_comment') self.assertEqual(cert.get_comment(), 'cert_comment') cert.set_comment('new_comment') self.assertEqual(cert.get_comment(), 'new_comment') cert.set_comment(b'new_comment') self.assertEqual(cert.get_comment(), 'new_comment') with self.assertRaises(asyncssh.KeyImportError): self.privkey.set_comment(b'\xff') with self.assertRaises(asyncssh.KeyImportError): self.pubkey.set_comment(b'\xff') with self.assertRaises(asyncssh.KeyImportError): self.usercert.set_comment(b'\xff') if self.x509_supported: with self.assertRaises(asyncssh.KeyImportError): self.userx509.set_comment(b'\xff') keypair = asyncssh.load_keypairs([self.privkey])[0] keypair.set_comment('new_comment') self.assertEqual(keypair.get_comment(), 'new_comment') keypair.set_comment(b'new_comment') self.assertEqual(keypair.get_comment(), 'new_comment') with self.assertRaises(asyncssh.KeyImportError): keypair.set_comment(b'\xff') def check_pkcs1_private(self): """Check PKCS#1 private key format""" with self.subTest('Import PKCS#1 PEM private'): self.import_pkcs1_private('pem') with self.subTest('Export PKCS#1 PEM private'): self.export_pkcs1_private('pem') with self.subTest('Import PKCS#1 DER private'): self.import_pkcs1_private('der') with self.subTest('Export PKCS#1 DER private'): self.export_pkcs1_private('der') for cipher, args in pkcs1_ciphers: with self.subTest('Import PKCS#1 PEM private (%s)' % cipher): self.import_pkcs1_private('pem', cipher, args) with self.subTest('Export PKCS#1 PEM private (%s)' % cipher): self.export_pkcs1_private('pem', cipher) def check_pkcs1_public(self): """Check PKCS#1 public key format""" with self.subTest('Import PKCS#1 PEM public'): self.import_pkcs1_public('pem') with self.subTest('Export PKCS#1 PEM public'): self.export_pkcs1_public('pem') with self.subTest('Import PKCS#1 DER public'): self.import_pkcs1_public('der') with self.subTest('Export PKCS#1 DER public'): self.export_pkcs1_public('der') def check_pkcs8_private(self): """Check PKCS#8 private key format""" with self.subTest('Import PKCS#8 PEM private'): self.import_pkcs8_private('pem', _openssl_available) with self.subTest('Export PKCS#8 PEM private'): self.export_pkcs8_private('pem', _openssl_available) with self.subTest('Import PKCS#8 DER private'): self.import_pkcs8_private('der', _openssl_available) with self.subTest('Export PKCS#8 DER private'): self.export_pkcs8_private('der', _openssl_available) for cipher, hash_alg, pbe_version, args, use_openssl in pkcs8_ciphers: with self.subTest('Import PKCS#8 PEM private (%s-%s-v%s)' % (cipher, hash_alg, pbe_version)): self.import_pkcs8_private('pem', use_openssl, cipher, hash_alg, pbe_version, args) with self.subTest('Export PKCS#8 PEM private (%s-%s-v%s)' % (cipher, hash_alg, pbe_version)): self.export_pkcs8_private('pem', use_openssl, cipher, hash_alg, pbe_version) with self.subTest('Import PKCS#8 DER private (%s-%s-v%s)' % (cipher, hash_alg, pbe_version)): self.import_pkcs8_private('der', use_openssl, cipher, hash_alg, pbe_version, args) with self.subTest('Export PKCS#8 DER private (%s-%s-v%s)' % (cipher, hash_alg, pbe_version)): self.export_pkcs8_private('der', 
use_openssl, cipher, hash_alg, pbe_version) def check_pkcs8_public(self): """Check PKCS#8 public key format""" with self.subTest('Import PKCS#8 PEM public'): self.import_pkcs8_public('pem') with self.subTest('Export PKCS#8 PEM public'): self.export_pkcs8_public('pem') with self.subTest('Import PKCS#8 DER public'): self.import_pkcs8_public('der') with self.subTest('Export PKCS#8 DER public'): self.export_pkcs8_public('der') def check_openssh_private(self): """Check OpenSSH private key format""" with self.subTest('Import OpenSSH private'): self.import_openssh_private(_openssh_available) with self.subTest('Export OpenSSH private'): self.export_openssh_private(_openssh_available) if bcrypt_available: # pragma: no branch for cipher, use_openssh in openssh_ciphers: with self.subTest('Import OpenSSH private (%s)' % cipher): self.import_openssh_private(use_openssh, cipher) with self.subTest('Export OpenSSH private (%s)' % cipher): self.export_openssh_private(use_openssh, cipher) def check_openssh_public(self): """Check OpenSSH public key format""" with self.subTest('Import OpenSSH public'): self.import_openssh_public() with self.subTest('Export OpenSSH public'): self.export_openssh_public() def check_openssh_certificate(self): """Check OpenSSH certificate format""" with self.subTest('Import OpenSSH user certificate'): self.import_openssh_certificate(CERT_TYPE_USER, 'usercert') with self.subTest('Export OpenSSH user certificate'): self.export_openssh_certificate(CERT_TYPE_USER, self.usercert) with self.subTest('Import OpenSSH host certificate'): self.import_openssh_certificate(CERT_TYPE_HOST, 'hostcert') with self.subTest('Export OpenSSH host certificate'): self.export_openssh_certificate(CERT_TYPE_HOST, self.hostcert) def check_rfc4716_public(self): """Check RFC4716 public key format""" with self.subTest('Import RFC4716 public'): self.import_rfc4716_public() with self.subTest('Export RFC4716 public'): self.export_rfc4716_public() def check_rfc4716_certificate(self): """Check RFC4716 certificate format""" with self.subTest('Import RFC4716 user certificate'): self.import_rfc4716_certificate(CERT_TYPE_USER, 'usercert') with self.subTest('Export RFC4716 user certificate'): self.export_rfc4716_certificate(CERT_TYPE_USER, self.usercert) with self.subTest('Import RFC4716 host certificate'): self.import_rfc4716_certificate(CERT_TYPE_HOST, 'hostcert') with self.subTest('Export RFC4716 host certificate'): self.export_rfc4716_certificate(CERT_TYPE_HOST, self.hostcert) def check_der_x509_certificate(self): """Check DER X.509 certificate format""" with self.subTest('Import DER X.509 user certificate'): self.import_der_x509_certificate(CERT_TYPE_USER, 'userx509') with self.subTest('Export DER X.509 user certificate'): self.export_der_x509_certificate(CERT_TYPE_USER, self.userx509) with self.subTest('Import DER X.509 host certificate'): self.import_der_x509_certificate(CERT_TYPE_HOST, 'hostx509') with self.subTest('Export DER X.509 host certificate'): self.export_der_x509_certificate(CERT_TYPE_HOST, self.hostx509) def check_pem_x509_certificate(self): """Check PEM X.509 certificate format""" with self.subTest('Import PEM X.509 user certificate'): self.import_pem_x509_certificate(CERT_TYPE_USER, 'userx509') with self.subTest('Export PEM X.509 user certificate'): self.export_pem_x509_certificate(CERT_TYPE_USER, self.userx509) with self.subTest('Import PEM X.509 host certificate'): self.import_pem_x509_certificate(CERT_TYPE_HOST, 'hostx509') with self.subTest('Export PEM X.509 host certificate'): 
self.export_pem_x509_certificate(CERT_TYPE_HOST, self.hostx509) def check_openssh_x509_certificate(self): """Check OpenSSH X.509 certificate format""" with self.subTest('Import OpenSSH X.509 user certificate'): self.import_openssh_x509_certificate(CERT_TYPE_USER, 'userx509') with self.subTest('Export OpenSSH X.509 user certificate'): self.export_openssh_x509_certificate(CERT_TYPE_USER, self.userx509) with self.subTest('Import OpenSSH X.509 host certificate'): self.import_openssh_x509_certificate(CERT_TYPE_HOST, 'hostx509') with self.subTest('Export OpenSSH X.509 host certificate'): self.export_openssh_x509_certificate(CERT_TYPE_HOST, self.hostx509) def check_certificate_options(self): """Check SSH certificate options""" cert = self.privca.generate_user_certificate( self.pubkey, 'name', force_command='command', source_address=['1.2.3.4'], permit_x11_forwarding=False, permit_agent_forwarding=False, permit_port_forwarding=False, permit_pty=False, permit_user_rc=False) cert.write_certificate('cert') self.check_certificate(CERT_TYPE_USER, 'openssh') for valid_after, valid_before in ((0, 1.), (datetime.now(), '+1m'), ('20160101', '20160102'), ('20160101000000', '20160102235959'), ('now', '1w2d3h4m5s'), ('-52w', '+52w')): cert = self.privca.generate_host_certificate( self.pubkey, 'name', valid_after=valid_after, valid_before=valid_before) cert.write_certificate('cert') cert2 = asyncssh.read_certificate('cert') self.assertEqual(cert2.data, cert.data) def check_certificate_errors(self, cert_type): """Check general and OpenSSH certificate error cases""" with self.subTest('Non-ASCII certificate'): with self.assertRaises(asyncssh.KeyImportError): asyncssh.import_certificate('\u0080\n') with self.subTest('Invalid SSH format'): with self.assertRaises(asyncssh.KeyImportError): asyncssh.import_certificate('xxx\n') with self.subTest('Invalid certificate packetization'): with self.assertRaises(asyncssh.KeyImportError): asyncssh.import_certificate( b'xxx ' + binascii.b2a_base64(b'\x00')) with self.subTest('Invalid certificate algorithm'): with self.assertRaises(asyncssh.KeyImportError): asyncssh.import_certificate( b'xxx ' + binascii.b2a_base64(String(b'xxx'))) with self.subTest('Invalid certificate critical option'): with self.assertRaises(asyncssh.KeyImportError): cert = self.make_certificate(cert_type, self.pubkey, self.privca, ('name',), options={b'xxx': b''}) asyncssh.import_certificate(cert) with self.subTest('Ignored certificate extension'): cert = self.make_certificate(cert_type, self.pubkey, self.privca, ('name',), extensions={b'xxx': b''}) self.assertIsNotNone(asyncssh.import_certificate(cert)) with self.subTest('Invalid certificate signature'): with self.assertRaises(asyncssh.KeyImportError): cert = self.make_certificate(cert_type, self.pubkey, self.privca, ('name',), bad_signature=True) asyncssh.import_certificate(cert) with self.subTest('Invalid characters in certificate key ID'): with self.assertRaises(asyncssh.KeyImportError): cert = self.make_certificate(cert_type, self.pubkey, self.privca, ('name',), key_id=b'\xff') asyncssh.import_certificate(cert) with self.subTest('Invalid characters in certificate principal'): with self.assertRaises(asyncssh.KeyImportError): cert = self.make_certificate(cert_type, self.pubkey, self.privca, (b'\xff',)) asyncssh.import_certificate(cert) if cert_type == CERT_TYPE_USER: with self.subTest('Invalid characters in force-command'): with self.assertRaises(asyncssh.KeyImportError): cert = self.make_certificate(cert_type, self.pubkey, self.privca, ('name',), 
options={'force-command': String(b'\xff')}) asyncssh.import_certificate(cert) with self.subTest('Invalid characters in source-address'): with self.assertRaises(asyncssh.KeyImportError): cert = self.make_certificate(cert_type, self.pubkey, self.privca, ('name',), options={'source-address': String(b'\xff')}) asyncssh.import_certificate(cert) with self.subTest('Invalid IP network in source-address'): with self.assertRaises(asyncssh.KeyImportError): cert = self.make_certificate(cert_type, self.pubkey, self.privca, ('name',), options={'source-address': String('1.1.1.256')}) asyncssh.import_certificate(cert) with self.subTest('Invalid certificate type'): with self.assertRaises(asyncssh.KeyImportError): cert = self.make_certificate(0, self.pubkey, self.privca, ('name',)) asyncssh.import_certificate(cert) with self.subTest('Mismatched certificate type'): with self.assertRaises(ValueError): cert = self.make_certificate(cert_type, self.pubkey, self.privca, ('name',)) cert = asyncssh.import_certificate(cert) self.validate_openssh(cert, cert_type ^ 3, 'name') with self.subTest('Certificate not yet valid'): with self.assertRaises(ValueError): cert = self.make_certificate(cert_type, self.pubkey, self.privca, ('name',), valid_after=0xffffffffffffffff) cert = asyncssh.import_certificate(cert) self.validate_openssh(cert, cert_type, 'name') with self.subTest('Certificate expired'): with self.assertRaises(ValueError): cert = self.make_certificate(cert_type, self.pubkey, self.privca, ('name',), valid_before=0) cert = asyncssh.import_certificate(cert) self.validate_openssh(cert, cert_type, 'name') with self.subTest('Certificate principal mismatch'): with self.assertRaises(ValueError): cert = self.make_certificate(cert_type, self.pubkey, self.privca, ('name',)) cert = asyncssh.import_certificate(cert) self.validate_openssh(cert, cert_type, 'name2') for fmt in ('der', 'pem', 'xxx'): with self.subTest('Invalid certificate export format', fmt=fmt): with self.assertRaises(asyncssh.KeyExportError): self.usercert.export_certificate(fmt) def check_x509_certificate_errors(self): """Check X.509 certificate error cases""" with self.subTest('Invalid DER format'): with self.assertRaises(asyncssh.KeyImportError): asyncssh.import_certificate(b'\x30') with self.subTest('Invalid DER format in certificate list'): with self.assertRaises(asyncssh.KeyImportError): with open('certlist', 'wb') as f: f.write(b'\x30') asyncssh.read_certificate_list('certlist') with self.subTest('Invalid PEM format'): with self.assertRaises(asyncssh.KeyImportError): asyncssh.import_certificate('-----') with self.subTest('Invalid PEM certificate type'): with self.assertRaises(asyncssh.KeyImportError): asyncssh.import_certificate('-----BEGIN XXX CERTIFICATE-----\n' '-----END XXX CERTIFICATE-----\n') with self.subTest('Missing PEM footer'): with self.assertRaises(asyncssh.KeyImportError): asyncssh.import_certificate('-----BEGIN CERTIFICATE-----\n') with self.subTest('Invalid PEM Base64'): with self.assertRaises(asyncssh.KeyImportError): asyncssh.import_certificate('-----BEGIN CERTIFICATE-----\n' 'X\n' '-----END CERTIFICATE-----\n') with self.subTest('Invalid PEM certificate data'): with self.assertRaises(asyncssh.KeyImportError): asyncssh.import_certificate('-----BEGIN CERTIFICATE-----\n' 'XXXX\n' '-----END CERTIFICATE-----\n') with self.subTest('Certificate not yet valid'): with self.assertRaises(ValueError): cert = self.privca.generate_x509_user_certificate( self.pubkey, 'OU=user', 'OU=root', valid_after=0xfffffffffffffffe) self.validate_x509(cert) with 
self.subTest('Certificate expired'): with self.assertRaises(ValueError): cert = self.privca.generate_x509_user_certificate( self.pubkey, 'OU=user', 'OU=root', valid_before=1) self.validate_x509(cert) with self.subTest('Certificate principal mismatch'): with self.assertRaises(ValueError): cert = self.privca.generate_x509_user_certificate( self.pubkey, 'OU=user', 'OU=root', principals=['name']) self.validate_x509(cert, 'name2') for fmt in ('rfc4716', 'xxx'): with self.subTest('Invalid certificate export format', fmt=fmt): with self.assertRaises(asyncssh.KeyExportError): self.userx509.export_certificate(fmt) with self.subTest('Empty certificate chain'): with self.assertRaises(asyncssh.KeyImportError): decode_ssh_certificate(String('x509v3-ssh-rsa') + UInt32(0) + UInt32(0)) def check_x509_certificate_subject(self): """Check X.509 certificate subject cases""" with self.subTest('Missing certificate subject algorithm'): with self.assertRaises(asyncssh.KeyImportError): import_certificate_subject('xxx') with self.subTest('Unknown certificate subject algorithm'): with self.assertRaises(asyncssh.KeyImportError): import_certificate_subject('xxx subject=OU=name') with self.subTest('Invalid certificate subject'): with self.assertRaises(asyncssh.KeyImportError): import_certificate_subject('x509v3-ssh-rsa xxx') subject = import_certificate_subject('x509v3-ssh-rsa subject=OU=name') self.assertEqual(subject, 'OU=name') def test_keys(self): """Check keys and certificates""" for alg_name, kwargs in self.generate_args: with self.subTest(alg_name=alg_name, **kwargs): self.privkey = asyncssh.generate_private_key( alg_name, comment='comment', **kwargs) self.privkey.write_private_key('priv', self.base_format) self.pubkey = self.privkey.convert_to_public() self.pubkey.write_public_key('pub', self.base_format) self.pubkey.write_public_key('sshpub', 'openssh') self.privca = asyncssh.generate_private_key(alg_name, **kwargs) self.privca.write_private_key('privca', self.base_format) self.pubca = self.privca.convert_to_public() self.pubca.write_public_key('pubca', self.base_format) self.usercert = self.privca.generate_user_certificate( self.pubkey, 'name') self.usercert.write_certificate('usercert') self.hostcert = self.privca.generate_host_certificate( self.pubkey, 'name') self.hostcert.write_certificate('hostcert') for f in ('priv', 'privca'): os.chmod(f, 0o600) self.assertEqual(self.privkey.get_algorithm(), alg_name) self.assertEqual(self.usercert.get_algorithm(), alg_name + '-cert-v01@openssh.com') if self.x509_supported: self.rootx509 = self.privca.generate_x509_ca_certificate( self.pubca, 'OU=root') self.rootx509.write_certificate('rootx509') self.userx509 = self.privca.generate_x509_user_certificate( self.pubkey, 'OU=user', 'OU=root') self.assertEqual(self.userx509.get_algorithm(), 'x509v3-' + alg_name) self.userx509.write_certificate('userx509') self.hostx509 = self.privca.generate_x509_host_certificate( self.pubkey, 'OU=host', 'OU=root') self.hostx509.write_certificate('hostx509') self.otherx509 = self.privca.generate_x509_user_certificate( self.pubkey, 'OU=other', 'OU=root') self.otherx509.write_certificate('otherx509') self.check_encode_errors() self.check_decode_errors() self.check_sshkey_base_errors() self.check_sign_and_verify() self.check_comment() if 'pkcs1' in self.private_formats: self.check_pkcs1_private() if 'pkcs1' in self.public_formats: self.check_pkcs1_public() if 'pkcs8' in self.private_formats: self.check_pkcs8_private() if 'pkcs8' in self.public_formats: self.check_pkcs8_public() 
                self.check_openssh_private()
                self.check_openssh_public()
                self.check_openssh_certificate()

                self.check_rfc4716_public()
                self.check_rfc4716_certificate()

                self.check_certificate_options()

                for cert_type in (CERT_TYPE_USER, CERT_TYPE_HOST):
                    self.check_certificate_errors(cert_type)

                if self.x509_supported:
                    self.check_der_x509_certificate()
                    self.check_pem_x509_certificate()
                    self.check_openssh_x509_certificate()
                    self.check_x509_certificate_errors()
                    self.check_x509_certificate_subject()


class TestDSA(_TestPublicKey):
    """Test DSA public keys"""

    keyclass = 'dsa'
    base_format = 'pkcs8-pem'
    private_formats = ('pkcs1', 'pkcs8', 'openssh')
    public_formats = ('pkcs1', 'pkcs8', 'openssh', 'rfc4716')
    default_cert_version = 'ssh-dss-cert-v01@openssh.com'
    x509_supported = x509_available

    generate_args = (('ssh-dss', {}),)


class TestRSA(_TestPublicKey):
    """Test RSA public keys"""

    keyclass = 'rsa'
    base_format = 'pkcs8-pem'
    private_formats = ('pkcs1', 'pkcs8', 'openssh')
    public_formats = ('pkcs1', 'pkcs8', 'openssh', 'rfc4716')
    default_cert_version = 'ssh-rsa-cert-v01@openssh.com'
    x509_supported = x509_available

    generate_args = (('ssh-rsa', {'key_size': 1024}),
                     ('ssh-rsa', {'key_size': 2048}),
                     ('ssh-rsa', {'key_size': 3072}),
                     ('ssh-rsa', {'exponent': 3}))


class TestEC(_TestPublicKey):
    """Test elliptic curve public keys"""

    keyclass = 'ec'
    base_format = 'pkcs8-pem'
    private_formats = ('pkcs1', 'pkcs8', 'openssh')
    public_formats = ('pkcs8', 'openssh', 'rfc4716')
    x509_supported = x509_available

    generate_args = (('ecdsa-sha2-nistp256', {}),
                     ('ecdsa-sha2-nistp384', {}),
                     ('ecdsa-sha2-nistp521', {}))

    @property
    def default_cert_version(self):
        """Return default SSH certificate version"""

        return self.privkey.algorithm.decode('ascii') + '-cert-v01@openssh.com'


if libnacl_available: # pragma: no branch
    class TestEd25519(_TestPublicKey):
        """Test Ed25519 public keys"""

        keyclass = 'ed25519'
        base_format = 'openssh'
        private_formats = ('openssh',)
        public_formats = ('openssh', 'rfc4716')
        default_cert_version = 'ssh-ed25519-cert-v01@openssh.com'

        generate_args = (('ssh-ed25519', {}),)


del _TestPublicKey


class _TestPublicKeyTopLevel(TempDirTestCase):
    """Top-level public key module tests"""

    def test_public_key(self):
        """Test public key top-level functions"""

        self.assertIsNotNone(get_public_key_algs())
        self.assertIsNotNone(get_certificate_algs())
        self.assertEqual(bool(get_x509_certificate_algs()), x509_available)

    def test_public_key_algorithm_mismatch(self):
        """Test algorithm mismatch in SSH public key"""

        privkey = asyncssh.generate_private_key('ssh-rsa')
        keydata = privkey.export_public_key('openssh')
        keydata = b'ssh-dss ' + keydata.split(None, 1)[1]

        with self.assertRaises(asyncssh.KeyImportError):
            asyncssh.import_public_key(keydata)

        with open('list', 'wb') as f:
            f.write(keydata)

        with self.assertRaises(asyncssh.KeyImportError):
            asyncssh.read_public_key_list('list')

    def test_pad_error(self):
        """Test for missing RFC 1423 padding on PBE decrypt"""

        with self.assertRaises(asyncssh.KeyEncryptionError):
            pkcs1_decrypt(b'', b'AES-128-CBC', os.urandom(16), 'x')

    def test_ec_explicit(self):
        """Test EC keys with explicit parameters"""

        if _openssl_available: # pragma: no branch
            for curve in ('secp256r1', 'secp384r1', 'secp521r1'):
                with self.subTest('Import EC key with explicit parameters',
                                  curve=curve):
                    run('openssl ecparam -out priv -noout -genkey -name %s '
                        '-param_enc explicit' % curve)

                    asyncssh.read_private_key('priv')

            with self.subTest('Import EC key with unknown explicit '
                              'parameters'):
                run('openssl ecparam -out priv -noout -genkey -name secp112r1 '
                    '-param_enc explicit')

                with self.assertRaises(asyncssh.KeyImportError):
                    asyncssh.read_private_key('priv')

    def test_generate_errors(self):
        """Test errors in private key and certificate generation"""

        for alg_name, kwargs in (('xxx', {}), ('ssh-dss', {'xxx': 0}),
                                 ('ssh-rsa', {'xxx': 0}),
                                 ('ecdsa-sha2-nistp256', {'xxx': 0}),
                                 ('ssh-ed25519', {'xxx': 0})):
            with self.subTest(alg_name=alg_name, **kwargs):
                with self.assertRaises(asyncssh.KeyGenerationError):
                    asyncssh.generate_private_key(alg_name, **kwargs)

        privkey = asyncssh.generate_private_key('ssh-rsa')
        pubkey = privkey.convert_to_public()
        privca = asyncssh.generate_private_key('ssh-rsa')

        with self.assertRaises(asyncssh.KeyGenerationError):
            privca.generate_user_certificate(pubkey, 'name', version=0)

        with self.assertRaises(ValueError):
            privca.generate_user_certificate(pubkey, 'name', valid_after=())

        with self.assertRaises(ValueError):
            privca.generate_user_certificate(pubkey, 'name',
                                             valid_after='xxx')

        with self.assertRaises(ValueError):
            privca.generate_user_certificate(pubkey, 'name',
                                             valid_after='now',
                                             valid_before='-1m')

        with self.assertRaises(ValueError):
            privca.generate_x509_user_certificate(pubkey, 'OU=user',
                                                  valid_after=())

        with self.assertRaises(ValueError):
            privca.generate_x509_user_certificate(pubkey, 'OU=user',
                                                  valid_after='xxx')

        with self.assertRaises(ValueError):
            privca.generate_x509_user_certificate(pubkey, 'OU=user',
                                                  valid_after='now',
                                                  valid_before='-1m')

        privca.x509_algorithms = None

        with self.assertRaises(asyncssh.KeyGenerationError):
            privca.generate_x509_user_certificate(pubkey, 'OU=user')
asyncssh-1.11.1/tests/test_saslprep.py000066400000000000000000000057261320320510200200260ustar00rootroot00000000000000# Copyright (c) 2015 by Ron Frederick .
# All rights reserved.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v1.0 which accompanies this
# distribution and is available at:
#
#     http://www.eclipse.org/legal/epl-v10.html
#
# Contributors:
#     Ron Frederick - initial implementation, API, and documentation

"""Unit tests for SASL string preparation"""

import unittest

from asyncssh.saslprep import saslprep, SASLPrepError


class _TestSASLPrep(unittest.TestCase):
    """Unit tests for saslprep module"""

    def test_nonstring(self):
        """Test passing a non-string value"""

        with self.assertRaises(TypeError):
            saslprep(b'xxx')

    def test_unassigned(self):
        """Test passing strings with unassigned code points"""

        for s in ('\u0221', '\u038b', '\u0510', '\u070e', '\u0900', '\u0a00'):
            with self.assertRaises(SASLPrepError, msg='U+%08x' % ord(s)):
                saslprep('abc' + s + 'def')

    def test_map_to_nothing(self):
        """Test passing strings with characters that map to nothing"""

        for s in ('\u00ad', '\u034f', '\u1806', '\u200c', '\u2060', '\ufe00'):
            self.assertEqual(saslprep('abc' + s + 'def'), 'abcdef',
                             msg='U+%08x' % ord(s))

    def test_map_to_whitespace(self):
        """Test passing strings with characters that map to whitespace"""

        for s in ('\u00a0', '\u1680', '\u2000', '\u202f', '\u205f', '\u3000'):
            self.assertEqual(saslprep('abc' + s + 'def'), 'abc def',
                             msg='U+%08x' % ord(s))

    def test_normalization(self):
        """Test Unicode normalization form KC conversions"""

        for (s, n) in (('\u00aa', 'a'), ('\u2168', 'IX')):
            self.assertEqual(saslprep('abc' + s + 'def'), 'abc' + n + 'def',
                             msg='U+%08x' % ord(s))

    def test_prohibited(self):
        """Test passing strings with prohibited characters"""

        for s in ('\u0000', '\u007f', '\u0080', '\u06dd', '\u180e', '\u200e',
                  '\u2028', '\u202a', '\u206a', '\u2ff0', '\u2ffb', '\ud800',
                  '\udfff', '\ue000', '\ufdd0', '\ufef9', '\ufffc', '\uffff',
                  '\U0001d173', '\U000E0001', '\U00100000', '\U0010fffd'):
            with self.assertRaises(SASLPrepError, msg='U+%08x' % ord(s)):
                saslprep('abc' + s + 'def')

    def test_bidi(self):
        """Test passing strings with bidirectional characters"""

        for s in ('\u05be\u05c0\u05c3\u05d0',   # RorAL only
                  'abc\u00c0\u00c1\u00c2',      # L only
                  '\u0627\u0031\u0628'):        # Mix of RorAL and other
            self.assertEqual(saslprep(s), s)

        with self.assertRaises(SASLPrepError):
            saslprep('abc\u05be\u05c0\u05c3')   # Mix of RorAL and L

        with self.assertRaises(SASLPrepError):
            saslprep('\u0627\u0031')            # RorAL not at both start & end
asyncssh-1.11.1/tests/test_sftp.py000066400000000000000000003132761320320510200171450ustar00rootroot00000000000000# Copyright (c) 2015-2017 by Ron Frederick .
# All rights reserved.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v1.0 which accompanies this
# distribution and is available at:
#
#     http://www.eclipse.org/legal/epl-v10.html
#
# Contributors:
#     Ron Frederick - initial implementation, API, and documentation

"""Unit tests for AsyncSSH SFTP client and server"""

import asyncio
import errno
import functools
import os
import posixpath
import shutil
import stat
import sys
import time
import unittest
from unittest.mock import patch

from asyncssh import SFTPError, SFTPAttrs, SFTPVFSAttrs, SFTPName, SFTPServer
from asyncssh import SEEK_CUR, SEEK_END
from asyncssh import FXP_INIT, FXP_VERSION, FXP_OPEN, FXP_CLOSE
from asyncssh import FXP_STATUS, FXP_HANDLE, FXP_DATA, FILEXFER_ATTR_UNDEFINED
from asyncssh import FX_OK, FX_PERMISSION_DENIED, FX_FAILURE
from asyncssh import scp
from asyncssh.misc import python35
from asyncssh.packet import SSHPacket, Byte, String, UInt32
from asyncssh.sftp import LocalFile, SFTPHandler, SFTPServerHandler

from .server import ServerTestCase
from .util import asynctest


def remove(files):
    """Remove files and directories"""

    for f in files.split():
        try:
            if os.path.isdir(f) and not os.path.islink(f):
                shutil.rmtree(f)
            else:
                os.remove(f)
        except OSError:
            pass


def sftp_test(func):
    """Decorator for running SFTP tests"""

    @asynctest
    @functools.wraps(func)
    def sftp_wrapper(self):
        """Run a test coroutine after opening an SFTP client"""

        with (yield from self.connect()) as conn:
            with (yield from conn.start_sftp_client()) as sftp:
                yield from asyncio.coroutine(func)(self, sftp)

            yield from sftp.wait_closed()

        yield from conn.wait_closed()

    return sftp_wrapper


class _ResetFileHandleServerHandler(SFTPServerHandler):
    """Reset file handle counter on each request to test handle-in-use check"""

    @asyncio.coroutine
    def recv_packet(self):
        """Reset next handle counter to test handle-in-use check"""

        self._next_handle = 0

        return (yield from super().recv_packet())


class _NonblockingCloseServerHandler(SFTPServerHandler):
    """Close the SFTP session without responding to a nonblocking close"""

    @asyncio.coroutine
    def _process_packet(self, pkttype, pktid, packet):
        """Close the session when a file close request is received"""

        if pkttype == FXP_CLOSE:
            yield from self._cleanup(None)
        else:
            yield from super()._process_packet(pkttype, pktid, packet)


class _ChrootSFTPServer(SFTPServer):
    """Return an SFTP server with a changed root"""

    def __init__(self, conn):
        os.mkdir('chroot')

        super().__init__(conn, 'chroot')

    def exit(self):
        """Clean up the changed root directory"""

        remove('chroot')


class _IOErrorSFTPServer(SFTPServer):
    """Return an I/O error during file writing"""

    @asyncio.coroutine
    def write(self, file_obj,
offset, data): """Return an error for writes past 64 KB in a file""" if offset >= 65536: raise SFTPError(FX_FAILURE, 'I/O error') else: super().write(file_obj, offset, data) class _NotImplSFTPServer(SFTPServer): """Return an error that a request is not implemented""" @asyncio.coroutine def symlink(self, old_path, new_path): """Return that symlinks aren't implemented""" raise NotImplementedError class _LongnameSFTPServer(SFTPServer): """Return a fixed set of files in response to a listdir request""" def listdir(self, path): """List the contents of a directory""" return list((b'.', b'..', SFTPName(b'.file'), SFTPName(b'file1'), SFTPName(b'file2', '', SFTPAttrs(permissions=0, nlink=1, uid=0, gid=0, size=0, mtime=0)), SFTPName(b'file3', '', SFTPAttrs(mtime=time.time())), SFTPName(b'file4', 56*b' ' + b'file4'))) def lstat(self, path): """Get attributes of a file, directory, or symlink""" return SFTPAttrs.from_local(super().lstat(path)) class _LargeDirSFTPServer(SFTPServer): """Return a really large listdir result""" @asyncio.coroutine def listdir(self, path): """Return a really large listdir result""" # pylint: disable=unused-argument return 100000 * [SFTPName(b'a', '', SFTPAttrs())] class _StatVFSSFTPServer(SFTPServer): """Return a fixed set of attributes in response to a statvfs request""" expected_statvfs = SFTPVFSAttrs(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11) def statvfs(self, path): """Get attributes of the file system containing a file""" # pylint: disable=unused-argument return self.expected_statvfs def fstatvfs(self, file_obj): """Return attributes of the file system containing an open file""" # pylint: disable=unused-argument return self.expected_statvfs class _ChownSFTPServer(SFTPServer): """Simulate file ownership changes""" _ownership = {} def setstat(self, path, attrs): """Set attributes of a file or directory""" # pylint: disable=unused-argument self._ownership[self.map_path(path)] = (attrs.uid, attrs.gid) def stat(self, path): """Get attributes of a file or directory, following symlinks""" # pylint: disable=unused-argument path = self.map_path(path) attrs = SFTPAttrs.from_local(os.stat(path)) if path in self._ownership: # pragma: no branch attrs.uid, attrs.gid = self._ownership[path] return attrs class _SymlinkSFTPServer(SFTPServer): """Implement symlink with non-standard argument order""" def symlink(self, newpath, oldpath): """Create a symbolic link""" return super().symlink(oldpath, newpath) class _SFTPAttrsSFTPServer(SFTPServer): """Implement stat which returns SFTPAttrs and raises SFTPError""" @asyncio.coroutine def stat(self, path): """Get attributes of a file or directory, following symlinks""" try: return SFTPAttrs.from_local(super().stat(path)) except OSError as exc: if exc.errno == errno.EACCES: raise SFTPError(FX_PERMISSION_DENIED, exc.strerror) else: raise SFTPError(FX_FAILURE, exc.strerror) class _AsyncSFTPServer(SFTPServer): """Implement all SFTP callbacks as coroutines""" @asyncio.coroutine def format_longname(self, name): """Format the long name associated with an SFTP name""" return super().format_longname(name) @asyncio.coroutine def open(self, path, pflags, attrs): """Open a file to serve to a remote client""" return super().open(path, pflags, attrs) @asyncio.coroutine def close(self, file_obj): """Close an open file or directory""" super().close(file_obj) @asyncio.coroutine def read(self, file_obj, offset, size): """Read data from an open file""" return super().read(file_obj, offset, size) @asyncio.coroutine def write(self, file_obj, offset, data): """Write data to 
an open file""" return super().write(file_obj, offset, data) @asyncio.coroutine def lstat(self, path): """Get attributes of a file, directory, or symlink""" return super().lstat(path) @asyncio.coroutine def fstat(self, file_obj): """Get attributes of an open file""" return super().fstat(file_obj) @asyncio.coroutine def setstat(self, path, attrs): """Set attributes of a file or directory""" super().setstat(path, attrs) @asyncio.coroutine def fsetstat(self, file_obj, attrs): """Set attributes of an open file""" super().fsetstat(file_obj, attrs) @asyncio.coroutine def listdir(self, path): """List the contents of a directory""" return super().listdir(path) @asyncio.coroutine def remove(self, path): """Remove a file or symbolic link""" super().remove(path) @asyncio.coroutine def mkdir(self, path, attrs): """Create a directory with the specified attributes""" super().mkdir(path, attrs) @asyncio.coroutine def rmdir(self, path): """Remove a directory""" super().rmdir(path) @asyncio.coroutine def realpath(self, path): """Return the canonical version of a path""" return super().realpath(path) @asyncio.coroutine def stat(self, path): """Get attributes of a file or directory, following symlinks""" return super().stat(path) @asyncio.coroutine def rename(self, oldpath, newpath): """Rename a file, directory, or link""" super().rename(oldpath, newpath) @asyncio.coroutine def readlink(self, path): """Return the target of a symbolic link""" return super().readlink(path) @asyncio.coroutine def symlink(self, oldpath, newpath): """Create a symbolic link""" super().symlink(oldpath, newpath) @asyncio.coroutine def posix_rename(self, oldpath, newpath): """Rename a file, directory, or link with POSIX semantics""" super().posix_rename(oldpath, newpath) @asyncio.coroutine def statvfs(self, path): """Get attributes of the file system containing a file""" return super().statvfs(path) @asyncio.coroutine def fstatvfs(self, file_obj): """Return attributes of the file system containing an open file""" return super().fstatvfs(file_obj) @asyncio.coroutine def link(self, oldpath, newpath): """Create a hard link""" super().link(oldpath, newpath) @asyncio.coroutine def fsync(self, file_obj): """Force file data to be written to disk""" super().fsync(file_obj) @asyncio.coroutine def exit(self): """Shut down this SFTP server""" super().exit() class _CheckSFTP(ServerTestCase): """Utility functions for AsyncSSH SFTP unit tests""" @classmethod def setUpClass(cls): """Check if symlink is available on this platform""" super().setUpClass() try: os.symlink('file', 'link') os.remove('link') cls._symlink_supported = True except OSError: # pragma: no cover cls._symlink_supported = False def _create_file(self, name, data=(), mode=None, utime=None): """Create a test file""" if data is (): data = str(id(self)) with open(name, 'w') as f: f.write(data) if mode is not None: os.chmod(name, mode) if utime is not None: os.utime(name, utime) def _check_attr(self, name1, name2, follow_symlinks, check_atime): """Check if attributes on two files are equal""" statfunc = os.stat if follow_symlinks else os.lstat attrs1 = statfunc(name1) attrs2 = statfunc(name2) self.assertEqual(stat.S_IMODE(attrs1.st_mode), stat.S_IMODE(attrs2.st_mode)) self.assertEqual(int(attrs1.st_mtime), int(attrs2.st_mtime)) if check_atime: self.assertEqual(int(attrs1.st_atime), int(attrs2.st_atime)) def _check_file(self, name1, name2, preserve=False, follow_symlinks=False, check_atime=True): """Check if two files are equal""" if preserve: self._check_attr(name1, name2, 
follow_symlinks, check_atime) with open(name1) as file1: with open(name2) as file2: self.assertEqual(file1.read(), file2.read()) def _check_stat(self, sftp_stat, local_stat): """Check if file attributes are equal""" self.assertEqual(sftp_stat.size, local_stat.st_size) self.assertEqual(sftp_stat.uid, local_stat.st_uid) self.assertEqual(sftp_stat.gid, local_stat.st_gid) self.assertEqual(sftp_stat.permissions, local_stat.st_mode) self.assertEqual(sftp_stat.atime, int(local_stat.st_atime)) self.assertEqual(sftp_stat.mtime, int(local_stat.st_mtime)) def _check_link(self, link, target): """Check if a symlink points to the right target""" self.assertEqual(os.readlink(link), target) class _TestSFTP(_CheckSFTP): """Unit tests for AsyncSSH SFTP client and server""" # pylint: disable=too-many-public-methods @classmethod @asyncio.coroutine def start_server(cls): """Start an SFTP server for the tests to use""" return (yield from cls.create_server(sftp_factory=True)) @sftp_test def test_copy(self, sftp): """Test copying a file over SFTP""" for method in ('get', 'put', 'copy'): with self.subTest(method=method): try: self._create_file('src') yield from getattr(sftp, method)('src', 'dst') self._check_file('src', 'dst') finally: remove('src dst') @sftp_test def test_copy_progress(self, sftp): """Test copying a file over SFTP with progress reporting""" def _report_progress(srcpath, dstpath, bytes_copied, total_bytes): """Monitor progress of copy""" # pylint: disable=unused-argument reports.append(bytes_copied) for method in ('get', 'put', 'copy'): reports = [] with self.subTest(method=method): try: self._create_file('src', 100000*'a') yield from getattr(sftp, method)( 'src', 'dst', block_size=8192, progress_handler=_report_progress) self._check_file('src', 'dst') self.assertEqual(len(reports), 13) self.assertEqual(reports[-1], 100000) finally: remove('src dst') @sftp_test def test_copy_preserve(self, sftp): """Test copying a file with preserved attributes over SFTP""" for method in ('get', 'put', 'copy'): with self.subTest(method=method): try: self._create_file('src', mode=0o666, utime=(1, 2)) yield from getattr(sftp, method)('src', 'dst', preserve=True) self._check_file('src', 'dst', preserve=True) finally: remove('src dst') @sftp_test def test_copy_recurse(self, sftp): """Test recursively copying a directory over SFTP""" for method in ('get', 'put', 'copy'): with self.subTest(method=method): try: os.mkdir('src') self._create_file('src/file1') if self._symlink_supported: # pragma: no branch os.symlink('file1', 'src/file2') yield from getattr(sftp, method)('src', 'dst', recurse=True) self._check_file('src/file1', 'dst/file1') if self._symlink_supported: # pragma: no branch self._check_link('dst/file2', 'file1') finally: remove('src dst') @sftp_test def test_copy_recurse_existing(self, sftp): """Test recursively copying over SFTP where target dir exists""" for method in ('get', 'put', 'copy'): with self.subTest(method=method): try: os.mkdir('src') os.mkdir('dst') os.mkdir('dst/src') self._create_file('src/file1') if self._symlink_supported: # pragma: no branch os.symlink('file1', 'src/file2') yield from getattr(sftp, method)('src', 'dst', recurse=True) self._check_file('src/file1', 'dst/src/file1') if self._symlink_supported: # pragma: no branch self._check_link('dst/src/file2', 'file1') finally: remove('src dst') @sftp_test def test_copy_follow_symlinks(self, sftp): """Test copying a file over SFTP while following symlinks""" if not self._symlink_supported: # pragma: no cover raise unittest.SkipTest('symlink 
not available') for method in ('get', 'put', 'copy'): with self.subTest(method=method): try: self._create_file('src') os.symlink('src', 'link') yield from getattr(sftp, method)('link', 'dst', follow_symlinks=True) self._check_file('src', 'dst') finally: remove('src dst link') @sftp_test def test_copy_invalid_name(self, sftp): """Test copying a file with an invalid name over SFTP""" for method in ('get', 'put', 'copy', 'mget', 'mput', 'mcopy'): with self.subTest(method=method): with self.assertRaises((FileNotFoundError, SFTPError, UnicodeDecodeError)): yield from getattr(sftp, method)(b'\xff') @sftp_test def test_copy_directory_no_recurse(self, sftp): """Test copying a directory over SFTP without recurse option""" for method in ('get', 'put', 'copy', 'mget', 'mput', 'mcopy'): with self.subTest(method=method): try: os.mkdir('dir') with self.assertRaises(SFTPError): yield from getattr(sftp, method)('dir') finally: remove('dir') @sftp_test def test_multiple_copy(self, sftp): """Test copying multiple files over SFTP""" for method in ('mget', 'mput', 'mcopy'): with self.subTest(method=method): try: self._create_file('src1', 'xxx') self._create_file('src2', 'yyy') os.mkdir('dst') yield from getattr(sftp, method)('src*', 'dst') self._check_file('src1', 'dst/src1') self._check_file('src2', 'dst/src2') finally: remove('src1 src2 dst') @sftp_test def test_multiple_copy_bytes_path(self, sftp): """Test copying multiple files with byte string paths over SFTP""" for method in ('mget', 'mput', 'mcopy'): with self.subTest(method=method): try: self._create_file('src1', 'xxx') self._create_file('src2', 'yyy') os.mkdir('dst') yield from getattr(sftp, method)(b'src*', b'dst') self._check_file('src1', 'dst/src1') self._check_file('src2', 'dst/src2') finally: remove('src1 src2 dst') @sftp_test def test_multiple_copy_target_not_dir(self, sftp): """Test copying multiple files over SFTP with non-directory target""" for method in ('mget', 'mput', 'mcopy'): with self.subTest(method=method): try: self._create_file('src') with self.assertRaises(SFTPError): yield from getattr(sftp, method)('src', 'dst') finally: remove('src') @sftp_test def test_multiple_copy_error_handler(self, sftp): """Test copying multiple files over SFTP with error handler""" def err_handler(exc): """Catch error for non-recursive copy of directory""" self.assertEqual(exc.reason, 'src2 is a directory') for method in ('mget', 'mput', 'mcopy'): with self.subTest(method=method): try: self._create_file('src1') os.mkdir('src2') os.mkdir('dst') yield from getattr(sftp, method)('src*', 'dst', error_handler=err_handler) self._check_file('src1', 'dst/src1') finally: remove('src1 src2 dst') @sftp_test def test_glob(self, sftp): """Test a glob pattern match over SFTP""" try: os.mkdir('filedir') self._create_file('file1') self._create_file('filedir/file2') self._create_file('filedir/file3') self.assertEqual(sorted((yield from sftp.glob('file*'))), ['file1', 'filedir']) self.assertEqual(sorted((yield from sftp.glob('./file*'))), ['./file1', './filedir']) self.assertEqual(sorted((yield from sftp.glob(b'file*'))), [b'file1', b'filedir']) self.assertEqual(sorted((yield from sftp.glob(['file*']))), ['file1', 'filedir']) self.assertEqual(sorted((yield from sftp.glob(['', 'file*']))), ['file1', 'filedir']) self.assertEqual(sorted((yield from sftp.glob(['file*/*2']))), ['filedir/file2']) self.assertEqual((yield from sftp.glob([b'fil*1', 'fil*dir'])), [b'file1', 'filedir']) finally: remove('file1 filedir') @sftp_test def test_glob_error(self, sftp): """Test a glob 
pattern match error over SFTP""" with self.assertRaises(SFTPError): yield from sftp.glob('file*') @sftp_test def test_glob_error_handler(self, sftp): """Test a glob pattern match with error handler over SFTP""" def err_handler(exc): """Catch error for nonexistent file1""" self.assertEqual(exc.reason, 'No matches found') try: self._create_file('file2') self.assertEqual((yield from sftp.glob(['file1*', 'file2*'], error_handler=err_handler)), ['file2']) finally: remove('file2') @sftp_test def test_stat(self, sftp): """Test getting attributes on a file""" try: os.mkdir('dir') self._create_file('file') if self._symlink_supported: # pragma: no branch os.symlink('bad', 'badlink') os.symlink('dir', 'dirlink') os.symlink('file', 'filelink') self._check_stat((yield from sftp.stat('dir')), os.stat('dir')) self._check_stat((yield from sftp.stat('file')), os.stat('file')) if self._symlink_supported: # pragma: no branch self._check_stat((yield from sftp.stat('dirlink')), os.stat('dir')) self._check_stat((yield from sftp.stat('filelink')), os.stat('file')) with self.assertRaises(SFTPError): yield from sftp.stat('badlink') # pragma: no branch self.assertTrue((yield from sftp.isdir('dir'))) self.assertFalse((yield from sftp.isdir('file'))) if self._symlink_supported: # pragma: no branch self.assertFalse((yield from sftp.isdir('badlink'))) self.assertTrue((yield from sftp.isdir('dirlink'))) self.assertFalse((yield from sftp.isdir('filelink'))) self.assertFalse((yield from sftp.isfile('dir'))) self.assertTrue((yield from sftp.isfile('file'))) if self._symlink_supported: # pragma: no branch self.assertFalse((yield from sftp.isfile('badlink'))) self.assertFalse((yield from sftp.isfile('dirlink'))) self.assertTrue((yield from sftp.isfile('filelink'))) self.assertFalse((yield from sftp.islink('dir'))) self.assertFalse((yield from sftp.islink('file'))) if self._symlink_supported: # pragma: no branch self.assertTrue((yield from sftp.islink('badlink'))) self.assertTrue((yield from sftp.islink('dirlink'))) self.assertTrue((yield from sftp.islink('filelink'))) finally: remove('dir file badlink dirlink filelink') @sftp_test def test_lstat(self, sftp): """Test getting attributes on a link""" if not self._symlink_supported: # pragma: no cover raise unittest.SkipTest('symlink not available') try: os.symlink('file', 'link') self._check_stat((yield from sftp.lstat('link')), os.lstat('link')) finally: remove('link') @sftp_test def test_setstat(self, sftp): """Test setting attributes on a file""" try: self._create_file('file') yield from sftp.setstat('file', SFTPAttrs(permissions=0o666)) self.assertEqual(stat.S_IMODE(os.stat('file').st_mode), 0o666) finally: remove('file') @unittest.skipIf(sys.platform == 'win32', 'skip statvfs tests on Windows') @sftp_test def test_statvfs(self, sftp): """Test getting attributes on a filesystem We can't compare the values returned by a live statvfs call since they can change at any time. See the separate _TestSFTStatPVFS class for a more complete test, but this is left in for code coverage purposes. 
""" self.assertIsInstance((yield from sftp.statvfs('.')), SFTPVFSAttrs) @unittest.skipIf(sys.platform == 'win32' and not python35, 'skip truncate tests on Windows before Python 3.5') @sftp_test def test_truncate(self, sftp): """Test truncating a file""" try: self._create_file('file', '01234567890123456789') yield from sftp.truncate('file', 10) self.assertEqual((yield from sftp.getsize('file')), 10) with open('file') as localf: self.assertEqual(localf.read(), '0123456789') finally: remove('file') @unittest.skipIf(sys.platform == 'win32', 'skip chown tests on Windows') @sftp_test def test_chown(self, sftp): """Test changing ownership of a file We can't change to a different user/group here if we're not root, so just change to the same user/group. See the separate _TestSFTPChown class for a more complete test, but this is left in for code coverage purposes. """ try: self._create_file('file') attrs = os.stat('file') yield from sftp.chown('file', attrs.st_uid, attrs.st_gid) new_attrs = os.stat('file') self.assertEqual(new_attrs.st_uid, attrs.st_uid) self.assertEqual(new_attrs.st_gid, attrs.st_gid) finally: remove('file') @unittest.skipIf(sys.platform == 'win32', 'skip chmod tests on Windows') @sftp_test def test_chmod(self, sftp): """Test changing permissions on a file""" try: self._create_file('file') yield from sftp.chmod('file', 0o1234) self.assertEqual(stat.S_IMODE(os.stat('file').st_mode), 0o1234) finally: remove('file') @sftp_test def test_utime(self, sftp): """Test changing access and modify times on a file""" try: self._create_file('file') yield from sftp.utime('file') yield from sftp.utime('file', (1, 2)) attrs = os.stat('file') self.assertEqual(attrs.st_atime, 1) self.assertEqual(attrs.st_mtime, 2) self.assertEqual((yield from sftp.getatime('file')), 1) self.assertEqual((yield from sftp.getmtime('file')), 2) finally: remove('file') @sftp_test def test_exists(self, sftp): """Test checking whether a file exists""" try: self._create_file('file1') self.assertTrue((yield from sftp.exists('file1'))) self.assertFalse((yield from sftp.exists('file2'))) with self.assertRaises(SFTPError): yield from sftp.exists(65536*'a') finally: remove('file1') @sftp_test def test_lexists(self, sftp): """Test checking whether a link exists""" if not self._symlink_supported: # pragma: no cover raise unittest.SkipTest('symlink not available') try: os.symlink('file', 'link1') self.assertTrue((yield from sftp.lexists('link1'))) self.assertFalse((yield from sftp.lexists('link2'))) finally: remove('link1') @sftp_test def test_remove(self, sftp): """Test removing a file""" try: self._create_file('file') yield from sftp.remove('file') with self.assertRaises(FileNotFoundError): os.stat('file') # pragma: no branch with self.assertRaises(SFTPError): yield from sftp.remove('file') finally: remove('file') @sftp_test def test_unlink(self, sftp): """Test unlinking a file""" try: self._create_file('file') yield from sftp.unlink('file') with self.assertRaises(FileNotFoundError): os.stat('file') # pragma: no branch with self.assertRaises(SFTPError): yield from sftp.unlink('file') finally: remove('file') @sftp_test def test_rename(self, sftp): """Test renaming a file""" try: self._create_file('file1') self._create_file('file2') with self.assertRaises(SFTPError): yield from sftp.rename('file1', 'file2') # pragma: no branch yield from sftp.rename('file1', 'file3') self.assertTrue(os.path.exists('file3')) finally: remove('file1 file2 file3') @sftp_test def test_posix_rename(self, sftp): """Test renaming a file that replaces a 
target file""" try: self._create_file('file1', 'xxx') self._create_file('file2', 'yyy') yield from sftp.posix_rename('file1', 'file2') with open('file2') as localf: self.assertEqual(localf.read(), 'xxx') finally: remove('file1 file2') @sftp_test def test_listdir(self, sftp): """Test listing files in a directory""" try: os.mkdir('dir') self._create_file('dir/file1') self._create_file('dir/file2') self.assertEqual(sorted((yield from sftp.listdir('dir'))), ['.', '..', 'file1', 'file2']) finally: remove('dir') @sftp_test def test_listdir_error(self, sftp): """Test error while listing contents of a directory""" @asyncio.coroutine def _readdir_error(self, handle): """Return an error on an SFTP readdir request""" # pylint: disable=unused-argument raise SFTPError(FX_FAILURE, 'I/O error') try: os.mkdir('dir') with patch('asyncssh.sftp.SFTPClientHandler.readdir', _readdir_error): with self.assertRaises(SFTPError): yield from sftp.listdir('dir') finally: remove('dir') @sftp_test def test_mkdir(self, sftp): """Test creating a directory""" try: yield from sftp.mkdir('dir') self.assertTrue(os.path.isdir('dir')) finally: remove('dir') @sftp_test def test_rmdir(self, sftp): """Test removing a directory""" try: os.mkdir('dir') yield from sftp.rmdir('dir') with self.assertRaises(FileNotFoundError): os.stat('dir') finally: remove('dir') @sftp_test def test_readlink(self, sftp): """Test reading a symlink""" if not self._symlink_supported: # pragma: no cover raise unittest.SkipTest('symlink not available') try: os.symlink('/file', 'link') self.assertEqual((yield from sftp.readlink('link')), '/file') self.assertEqual((yield from sftp.readlink(b'link')), b'/file') finally: remove('link') @sftp_test def test_readlink_decode_error(self, sftp): """Test unicode decode error while reading a symlink""" @asyncio.coroutine def _readlink_error(self, path): """Return invalid unicode on an SFTP readlink request""" # pylint: disable=unused-argument return [SFTPName(b'\xff')] with patch('asyncssh.sftp.SFTPClientHandler.readlink', _readlink_error): with self.assertRaises(SFTPError): yield from sftp.readlink('link') @sftp_test def test_symlink(self, sftp): """Test creating a symlink""" if not self._symlink_supported: # pragma: no cover raise unittest.SkipTest('symlink not available') try: yield from sftp.symlink('file', 'link') self._check_link('link', 'file') finally: remove('file link') @asynctest def test_symlink_encode_error(self): """Test creating a unicode symlink with no path encoding set""" if not self._symlink_supported: # pragma: no cover raise unittest.SkipTest('symlink not available') with (yield from self.connect()) as conn: sftp = yield from conn.start_sftp_client(path_encoding=None) with sftp: with self.assertRaises(SFTPError): yield from sftp.symlink('file', 'link') yield from sftp.wait_closed() yield from conn.wait_closed() @asynctest def test_nonstandard_symlink_client(self): """Test creating a symlink with opposite argument order""" if not self._symlink_supported: # pragma: no cover raise unittest.SkipTest('symlink not available') try: with (yield from self.connect(client_version='OpenSSH')) as conn: with (yield from conn.start_sftp_client()) as sftp: yield from sftp.symlink('link', 'file') self._check_link('link', 'file') # pragma: no branch yield from sftp.wait_closed() # pragma: no branch yield from conn.wait_closed() finally: remove('file link') @sftp_test def test_link(self, sftp): """Test creating a hard link""" try: self._create_file('file1') yield from sftp.link('file1', 'file2') 
self._check_file('file1', 'file2') finally: remove('file1 file2') @sftp_test def test_open_read(self, sftp): """Test reading data from a file""" f = None try: self._create_file('file', 'xxx') f = yield from sftp.open('file') self.assertEqual((yield from f.read()), 'xxx') finally: if f: # pragma: no branch yield from f.close() remove('file') @sftp_test def test_open_read_bytes(self, sftp): """Test reading bytes from a file""" f = None try: self._create_file('file', 'xxx') f = yield from sftp.open('file', 'rb') self.assertEqual((yield from f.read()), b'xxx') finally: if f: # pragma: no branch yield from f.close() remove('file') @sftp_test def test_open_read_offset_size(self, sftp): """Test reading at a specific offset and size""" f = None try: self._create_file('file', 'xxxxyyyy') f = yield from sftp.open('file') self.assertEqual((yield from f.read(4, 2)), 'xxyy') finally: if f: # pragma: no branch yield from f.close() remove('file') @sftp_test def test_open_read_nonexistent(self, sftp): """Test reading data from a nonexistent file""" f = None try: with self.assertRaises(SFTPError): f = yield from sftp.open('file') finally: if f: # pragma: no cover yield from f.close() @unittest.skipIf(sys.platform == 'win32', 'skip permission tests on Windows') @sftp_test def test_open_read_not_permitted(self, sftp): """Test reading data from a file with no read permission""" f = None try: self._create_file('file', mode=0) with self.assertRaises(SFTPError): f = yield from sftp.open('file') finally: if f: # pragma: no cover yield from f.close() remove('file') @sftp_test def test_open_write(self, sftp): """Test writing data to a file""" f = None try: f = yield from sftp.open('file', 'w') yield from f.write('xxx') yield from f.close() with open('file') as localf: self.assertEqual(localf.read(), 'xxx') finally: if f: # pragma: no branch yield from f.close() remove('file') @sftp_test def test_open_write_bytes(self, sftp): """Test writing bytes to a file""" f = None try: f = yield from sftp.open('file', 'wb') yield from f.write(b'xxx') yield from f.close() with open('file', 'rb') as localf: self.assertEqual(localf.read(), b'xxx') finally: if f: # pragma: no branch yield from f.close() remove('file') @sftp_test def test_open_truncate(self, sftp): """Test truncating a file at open time""" f = None try: self._create_file('file', 'xxxyyy') f = yield from sftp.open('file', 'w') yield from f.write('zzz') yield from f.close() with open('file') as localf: self.assertEqual(localf.read(), 'zzz') finally: if f: # pragma: no branch yield from f.close() remove('file') @sftp_test def test_open_append(self, sftp): """Test appending data to an existing file""" f = None try: self._create_file('file', 'xxx') f = yield from sftp.open('file', 'a+') yield from f.write('yyy') self.assertEqual((yield from f.read()), '') yield from f.close() with open('file') as localf: self.assertEqual(localf.read(), 'xxxyyy') finally: if f: # pragma: no branch yield from f.close() remove('file') @sftp_test def test_open_exclusive_create(self, sftp): """Test creating a new file""" f = None try: f = yield from sftp.open('file', 'x') yield from f.write('xxx') yield from f.close() with open('file') as localf: self.assertEqual(localf.read(), 'xxx') # pragma: no branch with self.assertRaises(SFTPError): f = yield from sftp.open('file', 'x') finally: if f: # pragma: no branch yield from f.close() remove('file') @sftp_test def test_open_exclusive_create_existing(self, sftp): """Test exclusive create of an existing file""" f = None try: 
self._create_file('file') with self.assertRaises(SFTPError): f = yield from sftp.open('file', 'x') finally: if f: # pragma: no cover yield from f.close() remove('file') @sftp_test def test_open_overwrite(self, sftp): """Test overwriting part of an existing file""" f = None try: self._create_file('file', 'xxxyyy') f = yield from sftp.open('file', 'r+') yield from f.write('zzz') yield from f.close() with open('file') as localf: self.assertEqual(localf.read(), 'zzzyyy') finally: if f: # pragma: no branch yield from f.close() remove('file') @sftp_test def test_open_overwrite_offset_size(self, sftp): """Test writing data at a specific offset""" f = None try: self._create_file('file', 'xxxxyyyy') f = yield from sftp.open('file', 'r+') yield from f.write('zz', 3) yield from f.close() with open('file') as localf: self.assertEqual(localf.read(), 'xxxzzyyy') finally: if f: # pragma: no branch yield from f.close() remove('file') @sftp_test def test_open_overwrite_nonexistent(self, sftp): """Test overwriting a nonexistent file""" f = None try: with self.assertRaises(SFTPError): f = yield from sftp.open('file', 'r+') finally: if f: # pragma: no cover yield from f.close() @sftp_test def test_file_seek(self, sftp): """Test seeking within a file""" f = None try: f = yield from sftp.open('file', 'w+') yield from f.write('xxxxyyyy') yield from f.seek(3) yield from f.write('zz') yield from f.seek(-3, SEEK_CUR) self.assertEqual((yield from f.read(4)), 'xzzy') yield from f.seek(-4, SEEK_END) self.assertEqual((yield from f.read()), 'zyyy') self.assertEqual((yield from f.read()), '') self.assertEqual((yield from f.read(1)), '') with self.assertRaises(ValueError): yield from f.seek(0, -1) yield from f.close() with open('file') as localf: self.assertEqual(localf.read(), 'xxxzzyyy') finally: if f: # pragma: no branch yield from f.close() remove('file') @sftp_test def test_file_stat(self, sftp): """Test getting attributes on an open file""" f = None try: self._create_file('file') f = yield from sftp.open('file') self._check_stat((yield from f.stat()), os.stat('file')) finally: if f: # pragma: no branch yield from f.close() remove('file') @sftp_test def test_file_setstat(self, sftp): """Test setting attributes on an open file""" f = None try: self._create_file('file') attrs = SFTPAttrs(permissions=0o666) f = yield from sftp.open('file') yield from f.setstat(attrs) yield from f.close() self.assertEqual(stat.S_IMODE(os.stat('file').st_mode), 0o666) finally: if f: # pragma: no branch yield from f.close() remove('file') @unittest.skipIf(sys.platform == 'win32' and not python35, 'skip truncate tests on Windows before Python 3.5') @sftp_test def test_file_truncate(self, sftp): """Test truncating an open file""" f = None try: self._create_file('file', '01234567890123456789') f = yield from sftp.open('file', 'a+') yield from f.truncate(10) self.assertEqual((yield from f.tell()), 10) self.assertEqual((yield from f.read(offset=0)), '0123456789') self.assertEqual((yield from f.tell()), 10) finally: if f: # pragma: no branch yield from f.close() remove('file') @sftp_test def test_file_utime(self, sftp): """Test changing access and modify times on an open file""" f = None try: self._create_file('file') f = yield from sftp.open('file') yield from f.utime() yield from f.utime((1, 2)) yield from f.close() attrs = os.stat('file') self.assertEqual(attrs.st_atime, 1) self.assertEqual(attrs.st_mtime, 2) finally: if f: # pragma: no branch yield from f.close() remove('file') @unittest.skipIf(sys.platform == 'win32', 'skip statvfs tests on 
Windows') @sftp_test def test_file_statvfs(self, sftp): """Test getting attributes on the filesystem containing an open file We can't compare the values returned by a live statvfs call since they can change at any time. See the separate _TestSFTStatPVFS class for a more complete test, but this is left in for code coverage purposes. """ f = None try: self._create_file('file') f = yield from sftp.open('file') self.assertIsInstance((yield from f.statvfs()), SFTPVFSAttrs) finally: if f: # pragma: no branch yield from f.close() remove('file') @sftp_test def test_file_sync(self, sftp): """Test file sync""" f = None try: f = yield from sftp.open('file', 'w') self.assertIsNone((yield from f.fsync())) finally: if f: # pragma: no branch yield from f.close() remove('file') @sftp_test def test_exited_session(self, sftp): """Test use of SFTP session after exit""" sftp.exit() yield from sftp.wait_closed() f = None try: with self.assertRaises(SFTPError): f = yield from sftp.open('file') finally: if f: # pragma: no cover yield from f.close() @sftp_test def test_cleanup_open_files(self, sftp): """Test cleanup of open file handles on exit""" try: self._create_file('file') yield from sftp.open('file') finally: sftp.exit() yield from sftp.wait_closed() remove('file') @sftp_test def test_invalid_open_mode(self, sftp): """Test opening file with invalid mode""" with self.assertRaises(ValueError): yield from sftp.open('file', 'z') @sftp_test def test_invalid_handle(self, sftp): """Test sending requests associated with an invalid file handle""" @asyncio.coroutine def _return_invalid_handle(self, path, pflags, attrs): """Return an invalid file handle""" # pylint: disable=unused-argument return UInt32(0xffffffff) with patch('asyncssh.sftp.SFTPClientHandler.open', _return_invalid_handle): f = yield from sftp.open('file') with self.assertRaises(SFTPError): yield from f.read() with self.assertRaises(SFTPError): yield from f.read(1) with self.assertRaises(SFTPError): yield from f.write('') with self.assertRaises(SFTPError): yield from f.stat() with self.assertRaises(SFTPError): yield from f.setstat(SFTPAttrs()) with self.assertRaises(SFTPError): yield from f.statvfs() with self.assertRaises(SFTPError): yield from f.fsync() with self.assertRaises(SFTPError): yield from f.close() @sftp_test def test_closed_file(self, sftp): """Test I/O operations on a closed file""" f = None try: self._create_file('file') with (yield from sftp.open('file')) as f: # Do an explicit close to test double-close yield from f.close() with self.assertRaises(ValueError): yield from f.read() # pragma: no branch with self.assertRaises(ValueError): yield from f.write('') # pragma: no branch with self.assertRaises(ValueError): yield from f.seek(0) # pragma: no branch with self.assertRaises(ValueError): yield from f.tell() # pragma: no branch with self.assertRaises(ValueError): yield from f.stat() # pragma: no branch with self.assertRaises(ValueError): yield from f.setstat(SFTPAttrs()) # pragma: no branch with self.assertRaises(ValueError): yield from f.statvfs() # pragma: no branch with self.assertRaises(ValueError): yield from f.truncate() # pragma: no branch with self.assertRaises(ValueError): yield from f.chown(0, 0) # pragma: no branch with self.assertRaises(ValueError): yield from f.chmod(0) # pragma: no branch with self.assertRaises(ValueError): yield from f.utime() # pragma: no branch with self.assertRaises(ValueError): yield from f.fsync() # pragma: no branch finally: if f: # pragma: no branch yield from f.close() remove('file') @sftp_test def 
test_exit_after_nonblocking_close(self, sftp): """Test exit before receiving reply to a non-blocking close""" # pylint: disable=no-self-use # We don't clean up this file, as it's still open when we exit with (yield from sftp.open('nonblocking_file', 'w')): pass def test_immediate_client_close(self): """Test closing SFTP channel immediately after opening""" @asyncio.coroutine def _closing_start(self): """Immediately close the SFTP channel""" self.exit() with patch('asyncssh.sftp.SFTPClientHandler.start', _closing_start): sftp_test(lambda self, sftp: None)(self) def test_no_init(self): """Test sending non-init request at start""" @asyncio.coroutine def _no_init_start(self): """Send a non-init request at start""" self.send_packet(Byte(FXP_OPEN), UInt32(0)) with patch('asyncssh.sftp.SFTPClientHandler.start', _no_init_start): sftp_test(lambda self, sftp: None)(self) def test_missing_version(self): """Test sending init with missing version""" @asyncio.coroutine def _missing_version_start(self): """Send an init request with missing version""" self.send_packet(Byte(FXP_INIT)) with patch('asyncssh.sftp.SFTPClientHandler.start', _missing_version_start): sftp_test(lambda self, sftp: None)(self) def test_nonstandard_version(self): """Test sending init with non-standard version""" # pylint: disable=no-self-use with patch('asyncssh.sftp._SFTP_VERSION', 4): sftp_test(lambda self, sftp: None)(self) def test_non_version_response(self): """Test sending a non-version message in response to init""" @asyncio.coroutine def _non_version_response(self): """Send a non-version response to init""" packet = yield from SFTPHandler.recv_packet(self) self.send_packet(Byte(FXP_STATUS)) return packet with patch('asyncssh.sftp.SFTPServerHandler.recv_packet', _non_version_response): with self.assertRaises(SFTPError): sftp_test(lambda self, sftp: None)(self) # pragma: no branch def test_unsupported_version_response(self): """Test sending an unsupported version in response to init""" @asyncio.coroutine def _unsupported_version_response(self): """Send an unsupported version in response to init""" packet = yield from SFTPHandler.recv_packet(self) self.send_packet(Byte(FXP_VERSION), UInt32(4)) return packet with patch('asyncssh.sftp.SFTPServerHandler.recv_packet', _unsupported_version_response): with self.assertRaises(SFTPError): sftp_test(lambda self, sftp: None)(self) # pragma: no branch def test_unknown_extension_response(self): """Test sending an unknown extension in version response""" with patch('asyncssh.sftp.SFTPServerHandler._extensions', [(b'xxx', b'1')]): sftp_test(lambda self, sftp: None)(self) def test_close_after_init(self): """Test close immediately after init request at start""" @asyncio.coroutine def _close_after_init_start(self): """Send a close immediately after init request at start""" self.send_packet(Byte(FXP_INIT), UInt32(3)) yield from self._cleanup(None) with patch('asyncssh.sftp.SFTPClientHandler.start', _close_after_init_start): sftp_test(lambda self, sftp: None)(self) def test_file_handle_skip(self): """Test skipping over a file handle already in use""" @asyncio.coroutine def _reset_file_handle(self, sftp): """Open multiple files, resetting next handle each time""" file1 = None file2 = None try: self._create_file('file1', 'xxx') self._create_file('file2', 'yyy') file1 = yield from sftp.open('file1') file2 = yield from sftp.open('file2') self.assertEqual((yield from file1.read()), 'xxx') self.assertEqual((yield from file2.read()), 'yyy') finally: if file1: # pragma: no branch yield from file1.close() 
if file2: # pragma: no branch yield from file2.close() remove('file1 file2') with patch('asyncssh.sftp.SFTPServerHandler', _ResetFileHandleServerHandler): sftp_test(_reset_file_handle)(self) @sftp_test def test_missing_request_pktid(self, sftp): """Test sending request without a packet ID""" @asyncio.coroutine def _missing_pktid(self, filename, pflags, attrs): """Send a request without a packet ID""" # pylint: disable=unused-argument self.send_packet(Byte(FXP_OPEN)) with patch('asyncssh.sftp.SFTPClientHandler.open', _missing_pktid): yield from sftp.open('file') @sftp_test def test_malformed_open_request(self, sftp): """Test sending malformed open request""" @asyncio.coroutine def _malformed_open(self, filename, pflags, attrs): """Send a malformed open request""" # pylint: disable=unused-argument return (yield from self._make_request(FXP_OPEN)) with patch('asyncssh.sftp.SFTPClientHandler.open', _malformed_open): with self.assertRaises(SFTPError): yield from sftp.open('file') @sftp_test def test_unknown_request(self, sftp): """Test sending unknown request type""" @asyncio.coroutine def _unknown_request(self, filename, pflags, attrs): """Send a request with an unknown type""" # pylint: disable=unused-argument return (yield from self._make_request(0xff)) with patch('asyncssh.sftp.SFTPClientHandler.open', _unknown_request): with self.assertRaises(SFTPError): yield from sftp.open('file') @sftp_test def test_unrecognized_response_pktid(self, sftp): """Test sending a response with an unrecognized packet ID""" @asyncio.coroutine def _unrecognized_response_pktid(self, pkttype, pktid, packet): """Send a response with an unrecognized packet ID""" # pylint: disable=unused-argument self.send_packet(Byte(FXP_HANDLE), UInt32(0xffffffff), String('')) with patch('asyncssh.sftp.SFTPServerHandler._process_packet', _unrecognized_response_pktid): with self.assertRaises(SFTPError): yield from sftp.open('file') @sftp_test def test_bad_response_type(self, sftp): """Test sending a response with an incorrect response type""" @asyncio.coroutine def _bad_response_type(self, pkttype, pktid, packet): """Send a response with an incorrect response type""" # pylint: disable=unused-argument self.send_packet(Byte(FXP_DATA), UInt32(pktid), String('')) with patch('asyncssh.sftp.SFTPServerHandler._process_packet', _bad_response_type): with self.assertRaises(SFTPError): yield from sftp.open('file') @sftp_test def test_unexpected_ok_response(self, sftp): """Test sending an unexpected FX_OK response""" @asyncio.coroutine def _unexpected_ok_response(self, pkttype, pktid, packet): """Send an unexpected FX_OK response""" # pylint: disable=unused-argument self.send_packet(Byte(FXP_STATUS), UInt32(pktid), UInt32(FX_OK), String(''), String('')) with patch('asyncssh.sftp.SFTPServerHandler._process_packet', _unexpected_ok_response): with self.assertRaises(SFTPError): yield from sftp.open('file') @sftp_test def test_malformed_ok_response(self, sftp): """Test sending an FX_OK response containing invalid Unicode""" @asyncio.coroutine def _malformed_ok_response(self, pkttype, pktid, packet): """Send an FX_OK response containing invalid Unicode""" # pylint: disable=unused-argument self.send_packet(Byte(FXP_STATUS), UInt32(pktid), UInt32(FX_OK), String(b'\xff'), String('')) with patch('asyncssh.sftp.SFTPServerHandler._process_packet', _malformed_ok_response): with self.assertRaises(SFTPError): yield from sftp.open('file') @sftp_test def test_malformed_realpath_response(self, sftp): """Test receiving malformed realpath response""" 
@asyncio.coroutine def _malformed_realpath(self, path): """Return a malformed realpath response""" # pylint: disable=unused-argument return [SFTPName(''), SFTPName('')] with patch('asyncssh.sftp.SFTPClientHandler.realpath', _malformed_realpath): with self.assertRaises(SFTPError): yield from sftp.realpath('.') @sftp_test def test_malformed_readlink_response(self, sftp): """Test receiving malformed readlink response""" @asyncio.coroutine def _malformed_readlink(self, path): """Return a malformed readlink response""" # pylint: disable=unused-argument return [SFTPName(''), SFTPName('')] with patch('asyncssh.sftp.SFTPClientHandler.readlink', _malformed_readlink): with self.assertRaises(SFTPError): yield from sftp.readlink('.') def test_unsupported_extensions(self): """Test using extensions on a server that doesn't support them""" def _unsupported_extensions(self, sftp): """Try using unsupported extensions""" try: self._create_file('file1') with self.assertRaises(SFTPError): yield from sftp.statvfs('.') # pragma: no branch f = yield from sftp.open('file1') with self.assertRaises(SFTPError): yield from f.statvfs() # pragma: no branch with self.assertRaises(SFTPError): yield from sftp.posix_rename('file1', # pragma: no branch 'file2') with self.assertRaises(SFTPError): yield from sftp.link('file1', 'file2') # pragma: no branch with self.assertRaises(SFTPError): yield from f.fsync() finally: if f: # pragma: no branch yield from f.close() remove('file1') with patch('asyncssh.sftp.SFTPServerHandler._extensions', []): sftp_test(_unsupported_extensions)(self) def test_outstanding_nonblocking_close(self): """Test session cleanup with an outstanding non-blocking close""" @asyncio.coroutine def _nonblocking_close(self, sftp): """Initiate nonblocking close that triggers cleanup""" # pylint: disable=unused-argument try: with (yield from sftp.open('file', 'w')): pass finally: sftp.exit() yield from sftp.wait_closed() remove('file') with patch('asyncssh.sftp.SFTPServerHandler', _NonblockingCloseServerHandler): sftp_test(_nonblocking_close)(self) class _TestSFTPChroot(_CheckSFTP): """Unit test for SFTP server with changed root""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SFTP server with a changed root""" return (yield from cls.create_server(sftp_factory=_ChrootSFTPServer)) @sftp_test def test_chroot_copy(self, sftp): """Test copying a file to an FTP server with a changed root""" try: self._create_file('src') yield from sftp.put('src', 'dst') self._check_file('src', 'chroot/dst') finally: remove('src chroot/dst') @sftp_test def test_chroot_glob(self, sftp): """Test a glob pattern match over SFTP with a changed root""" try: self._create_file('chroot/file1') self._create_file('chroot/file2') self.assertEqual(sorted((yield from sftp.glob('/file*'))), ['/file1', '/file2']) finally: remove('chroot/file1 chroot/file2') @sftp_test def test_chroot_realpath(self, sftp): """Test canonicalizing a path on an SFTP server with a changed root""" self.assertEqual((yield from sftp.realpath('/dir/../file')), '/file') @sftp_test def test_getcwd_and_chdir(self, sftp): """Test changing directory on an SFTP server with a changed root""" try: os.mkdir('chroot/dir') self.assertEqual((yield from sftp.getcwd()), '/') yield from sftp.chdir('dir') self.assertEqual((yield from sftp.getcwd()), '/dir') finally: remove('chroot/dir') @sftp_test def test_chroot_readlink(self, sftp): """Test reading symlinks on an FTP server with a changed root""" if not self._symlink_supported: # pragma: no cover raise 
unittest.SkipTest('symlink not available') try: root = os.path.join(os.getcwd(), 'chroot') os.symlink(root, 'chroot/link1') os.symlink(os.path.join(root, 'file'), 'chroot/link2') os.symlink('/xxx', 'chroot/link3') self.assertEqual((yield from sftp.readlink('link1')), '/') self.assertEqual((yield from sftp.readlink('link2')), '/file') with self.assertRaises(SFTPError): yield from sftp.readlink('link3') finally: remove('chroot/link1 chroot/link2 chroot/link3') @sftp_test def test_chroot_symlink(self, sftp): """Test setting a symlink on an SFTP server with a changed root""" if not self._symlink_supported: # pragma: no cover raise unittest.SkipTest('symlink not available') try: yield from sftp.symlink('/file', 'link1') yield from sftp.symlink('../../file', 'link2') self._check_link('chroot/link1', os.path.abspath('chroot/file')) self._check_link('chroot/link2', 'file') finally: remove('chroot/link1 chroot/link2') class _TestSFTPIOError(_CheckSFTP): """Unit test for SFTP server returning file I/O error""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SFTP server which returns file I/O errors""" return (yield from cls.create_server(sftp_factory=_IOErrorSFTPServer)) @sftp_test def test_put_error(self, sftp): """Test error when putting a file to an SFTP server""" for method in ('put', 'copy'): with self.subTest(method=method): try: self._create_file('src', 4*1024*1024*'\0') with self.assertRaises((FileNotFoundError, SFTPError)): yield from getattr(sftp, method)('src', 'dst') finally: remove('src dst') class _TestSFTPNotImplemented(_CheckSFTP): """Unit test for SFTP server returning not-implemented error""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SFTP server which returns not-implemented errors""" return (yield from cls.create_server(sftp_factory=_NotImplSFTPServer)) @sftp_test def test_symlink_error(self, sftp): """Test error when creating a symbolic link on an SFTP server""" with self.assertRaises(SFTPError): yield from sftp.symlink('file', 'link') class _TestSFTPLongname(_CheckSFTP): """Unit test for SFTP server formatting directory listings""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SFTP server which returns a fixed directory listing""" return (yield from cls.create_server(sftp_factory=_LongnameSFTPServer)) @sftp_test def test_longname(self, sftp): """Test long name formatting in SFTP opendir call""" for file in (yield from sftp.readdir('/')): self.assertEqual(file.longname[56:], file.filename) @sftp_test def test_glob_hidden(self, sftp): """Test a glob pattern match on hidden files""" self.assertEqual((yield from sftp.glob('/.*')), ['/.file']) @unittest.skipIf(sys.platform == 'win32', 'skip uid/gid tests on Windows') @sftp_test def test_getpwuid_error(self, sftp): """Test long name formatting where user name can't be resolved""" def getpwuid_error(uid): """Simulate not being able to resolve user name""" # pylint: disable=unused-argument raise KeyError with patch('pwd.getpwuid', getpwuid_error): result = yield from sftp.readdir('/') self.assertEqual(result[3].longname[16:24], ' ') self.assertEqual(result[4].longname[16:24], '0 ') @unittest.skipIf(sys.platform == 'win32', 'skip uid/gid tests on Windows') @sftp_test def test_getgrgid_error(self, sftp): """Test long name formatting where group name can't be resolved""" def getgrgid_error(gid): """Simulate not being able to resolve group name""" # pylint: disable=unused-argument raise KeyError with patch('grp.getgrgid', getgrgid_error): result = yield from sftp.readdir('/') 
            self.assertEqual(result[3].longname[25:33], '        ')
            self.assertEqual(result[4].longname[25:33], '0       ')

    @sftp_test
    def test_strftime_error(self, sftp):
        """Test long name formatting with strftime not supporting %e"""

        orig_strftime = time.strftime

        def strftime_error(fmt, t):
            """Simulate Windows strftime that doesn't support %e"""

            # pylint: disable=unused-argument

            if '%e' in fmt:
                raise ValueError
            else:
                return orig_strftime(fmt, t)

        with patch('time.strftime', strftime_error):
            result = yield from sftp.readdir('/')

            self.assertEqual(result[3].longname[51:55], '    ')
            self.assertIn(result[4].longname[51:55], ('1969', '1970'))


class _TestSFTPLargeListDir(_CheckSFTP):
    """Unit test for SFTP server returning large listdir result"""

    @classmethod
    @asyncio.coroutine
    def start_server(cls):
        """Start an SFTP server which returns a large listdir result"""

        return (yield from cls.create_server(sftp_factory=_LargeDirSFTPServer))

    @sftp_test
    def test_large_listdir(self, sftp):
        """Test large listdir result"""

        self.assertEqual(len((yield from sftp.readdir('/'))), 100000)


@unittest.skipIf(sys.platform == 'win32', 'skip statvfs tests on Windows')
class _TestSFTPStatVFS(_CheckSFTP):
    """Unit test for SFTP server filesystem attributes"""

    @classmethod
    @asyncio.coroutine
    def start_server(cls):
        """Start an SFTP server which returns fixed filesystem attrs"""

        return (yield from cls.create_server(sftp_factory=_StatVFSSFTPServer))

    def _check_statvfs(self, sftp_statvfs):
        """Check if filesystem attributes are equal"""

        expected_statvfs = _StatVFSSFTPServer.expected_statvfs

        self.assertEqual(sftp_statvfs.bsize, expected_statvfs.bsize)
        self.assertEqual(sftp_statvfs.frsize, expected_statvfs.frsize)
        self.assertEqual(sftp_statvfs.blocks, expected_statvfs.blocks)
        self.assertEqual(sftp_statvfs.bfree, expected_statvfs.bfree)
        self.assertEqual(sftp_statvfs.bavail, expected_statvfs.bavail)
        self.assertEqual(sftp_statvfs.files, expected_statvfs.files)
        self.assertEqual(sftp_statvfs.ffree, expected_statvfs.ffree)
        self.assertEqual(sftp_statvfs.favail, expected_statvfs.favail)
        self.assertEqual(sftp_statvfs.fsid, expected_statvfs.fsid)
        self.assertEqual(sftp_statvfs.flags, expected_statvfs.flags)
        self.assertEqual(sftp_statvfs.namemax, expected_statvfs.namemax)
        self.assertEqual(repr(sftp_statvfs), repr(expected_statvfs))

    @sftp_test
    def test_statvfs(self, sftp):
        """Test getting attributes on a filesystem"""

        self._check_statvfs((yield from sftp.statvfs('.')))

    @sftp_test
    def test_file_statvfs(self, sftp):
        """Test getting attributes on the filesystem containing
           an open file"""

        f = None

        try:
            self._create_file('file')

            f = yield from sftp.open('file')
            self._check_statvfs((yield from f.statvfs()))
        finally:
            if f: # pragma: no branch
                yield from f.close()

            remove('file')


@unittest.skipIf(sys.platform == 'win32', 'skip chown tests on Windows')
class _TestSFTPChown(_CheckSFTP):
    """Unit test for SFTP server file ownership"""

    @classmethod
    @asyncio.coroutine
    def start_server(cls):
        """Start an SFTP server which simulates file ownership changes"""

        return (yield from cls.create_server(sftp_factory=_ChownSFTPServer))

    @sftp_test
    def test_chown(self, sftp):
        """Test changing ownership of a file"""

        try:
            self._create_file('file')

            yield from sftp.chown('file', 1, 2)

            attrs = yield from sftp.stat('file')
            self.assertEqual(attrs.uid, 1)
            self.assertEqual(attrs.gid, 2)
        finally:
            remove('file')


class _TestSFTPAttrs(unittest.TestCase):
    """Unit test for SFTPAttrs object"""

    def test_attrs(self):
        """Test encoding and decoding of SFTP attributes"""

        for kwargs in ({'size': 1234}, {'uid': 1, 'gid': 2},
{'permissions': 0o7777}, {'atime': 1, 'mtime': 2}, {'extended': [(b'a1', b'v1'), (b'a2', b'v2')]}): attrs = SFTPAttrs(**kwargs) packet = SSHPacket(attrs.encode()) self.assertEqual(repr(SFTPAttrs.decode(packet)), repr(attrs)) def test_illegal_attrs(self): """Test decoding illegal SFTP attributes value""" with self.assertRaises(SFTPError): SFTPAttrs.decode(SSHPacket(UInt32(FILEXFER_ATTR_UNDEFINED))) class _TestSFTPNonstandardSymlink(_CheckSFTP): """Unit tests for SFTP server with non-standard symlink order""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SFTP server for the tests to use""" return (yield from cls.create_server(server_version='OpenSSH', sftp_factory=_SymlinkSFTPServer)) @asynctest def test_nonstandard_symlink_client(self): """Test creating a symlink with opposite argument order""" if not self._symlink_supported: # pragma: no cover raise unittest.SkipTest('symlink not available') try: with (yield from self.connect(client_version='OpenSSH')) as conn: with (yield from conn.start_sftp_client()) as sftp: yield from sftp.symlink('link', 'file') self._check_link('link', 'file') # pragma: no branch yield from sftp.wait_closed() # pragma: no branch yield from conn.wait_closed() finally: remove('file link') class _TestSFTPAsync(_TestSFTP): """Unit test for an async SFTPServer""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SFTP server with coroutine callbacks""" return (yield from cls.create_server(sftp_factory=_AsyncSFTPServer)) @sftp_test def test_async_realpath(self, sftp): """Test canonicalizing a path on an async SFTP server""" self.assertEqual((yield from sftp.realpath('dir/../file')), posixpath.join((yield from sftp.getcwd()), 'file')) class _CheckSCP(_CheckSFTP): """Utility functions for AsyncSSH SCP unit tests""" @classmethod @asyncio.coroutine def asyncSetUpClass(cls): """Set up SCP target host/port tuple""" yield from super().asyncSetUpClass() cls._scp_server = (cls._server_addr, cls._server_port) @classmethod @asyncio.coroutine def start_server(cls): """Start an SFTP server with SCP enabled for the tests to use""" return (yield from cls.create_server(sftp_factory=True, allow_scp=True)) class _TestSCP(_CheckSCP): """Unit tests for AsyncSSH SCP client and server""" @asynctest def test_get(self): """Test getting a file over SCP""" try: self._create_file('src') yield from scp((self._scp_server, 'src'), 'dst') self._check_file('src', 'dst') finally: remove('src dst') @asynctest def test_get_bytes_path(self): """Test getting a file with a byte string path over SCP""" try: self._create_file('src') yield from scp((self._scp_server, b'src'), b'dst') self._check_file('src', 'dst') finally: remove('src dst') @asynctest def test_get_progress(self): """Test getting a file over SCP with progress reporting""" def _report_progress(srcpath, dstpath, bytes_copied, total_bytes): """Monitor progress of copy""" # pylint: disable=unused-argument reports.append(bytes_copied) reports = [] try: self._create_file('src', 100000*'a') yield from scp((self._scp_server, 'src'), 'dst', block_size=8192, progress_handler=_report_progress) self._check_file('src', 'dst') self.assertEqual(len(reports), 13) self.assertEqual(reports[-1], 100000) finally: remove('src dst') @asynctest def test_get_preserve(self): """Test getting a file with preserved attributes over SCP""" try: self._create_file('src', utime=(1, 2)) yield from scp((self._scp_server, 'src'), 'dst', preserve=True) self._check_file('src', 'dst', preserve=True, check_atime=False) finally: remove('src dst') 
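    # A hedged usage sketch (not part of the original test suite): the
    # surrounding tests drive AsyncSSH's scp() coroutine against the test
    # SCP server, passing either a (host, port) tuple or an open
    # SSHClientConnection together with a path.  Against a real server the
    # same calls look roughly like the commented example below; the host
    # name 'example.com' and the file names are illustrative assumptions,
    # not values used by these tests.
    #
    #     import asyncio
    #
    #     import asyncssh
    #
    #     @asyncio.coroutine
    #     def copy_files():
    #         """Copy a file in each direction, preserving timestamps"""
    #
    #         with (yield from asyncssh.connect('example.com')) as conn:
    #             # Upload: local source path, (connection, path) destination
    #             yield from asyncssh.scp('local.txt', (conn, 'remote.txt'),
    #                                     preserve=True)
    #
    #             # Download: (connection, path) source, local destination
    #             yield from asyncssh.scp((conn, 'remote.txt'), 'copy.txt',
    #                                     preserve=True)
    #
    #     asyncio.get_event_loop().run_until_complete(copy_files())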
@asynctest def test_get_recurse(self): """Test recursively getting a directory over SCP""" try: os.mkdir('src') self._create_file('src/file1') yield from scp((self._scp_server, 'src'), 'dst', recurse=True) self._check_file('src/file1', 'dst/file1') finally: remove('src dst') @asynctest def test_get_error_handler(self): """Test getting multiple files over SCP with error handler""" def err_handler(exc): """Catch error for non-recursive copy of directory""" self.assertEqual(exc.reason, 'scp: Not a regular file: src2') try: self._create_file('src1') os.mkdir('src2') os.mkdir('dst') yield from scp((self._scp_server, 'src*'), 'dst', error_handler=err_handler) self._check_file('src1', 'dst/src1') finally: remove('src1 src2 dst') @asynctest def test_get_recurse_existing(self): """Test getting a directory over SCP where target dir exists""" try: os.mkdir('src') os.mkdir('dst') os.mkdir('dst/src') self._create_file('src/file1') yield from scp((self._scp_server, 'src'), 'dst', recurse=True) self._check_file('src/file1', 'dst/src/file1') finally: remove('src dst') @unittest.skipIf(sys.platform == 'win32', 'skip permission tests on Windows') @asynctest def test_get_not_permitted(self): """Test getting a file with no read permissions over SCP""" try: self._create_file('src', mode=0) with self.assertRaises(SFTPError): yield from scp((self._scp_server, 'src'), 'dst') finally: remove('src dst') @asynctest def test_get_directory_as_file(self): """Test getting a file which is actually a directory over SCP""" try: os.mkdir('src') with self.assertRaises(SFTPError): yield from scp((self._scp_server, 'src'), 'dst') finally: remove('src dst') @asynctest def test_get_non_directory_in_path(self): """Test getting a file with a non-directory in path over SCP""" try: self._create_file('src') with self.assertRaises(SFTPError): yield from scp((self._scp_server, 'src/xxx'), 'dst') finally: remove('src dst') @asynctest def test_get_recurse_not_directory(self): """Test getting a directory over SCP where target is not directory""" try: os.mkdir('src') self._create_file('dst') self._create_file('src/file1') with self.assertRaises(SFTPError): yield from scp((self._scp_server, 'src'), 'dst', recurse=True) finally: remove('src dst') @asynctest def test_put(self): """Test putting a file over SCP""" try: self._create_file('src') yield from scp('src', (self._scp_server, 'dst')) self._check_file('src', 'dst') finally: remove('src dst') @asynctest def test_put_bytes_path(self): """Test putting a file with a byte string path over SCP""" try: self._create_file('src') yield from scp(b'src', (self._scp_server, b'dst')) self._check_file('src', 'dst') finally: remove('src dst') @asynctest def test_put_progress(self): """Test putting a file over SCP with progress reporting""" def _report_progress(srcpath, dstpath, bytes_copied, total_bytes): """Monitor progress of copy""" # pylint: disable=unused-argument reports.append(bytes_copied) reports = [] try: self._create_file('src', 100000*'a') yield from scp('src', (self._scp_server, 'dst'), block_size=8192, progress_handler=_report_progress) self._check_file('src', 'dst') self.assertEqual(len(reports), 13) self.assertEqual(reports[-1], 100000) finally: remove('src dst') @asynctest def test_put_preserve(self): """Test putting a file with preserved attributes over SCP""" try: self._create_file('src', utime=(1, 2)) yield from scp('src', (self._scp_server, 'dst'), preserve=True) self._check_file('src', 'dst', preserve=True, check_atime=False) finally: remove('src dst') @asynctest def 
test_put_recurse(self): """Test recursively putting a directory over SCP""" try: os.mkdir('src') self._create_file('src/file1') yield from scp('src', (self._scp_server, 'dst'), recurse=True) self._check_file('src/file1', 'dst/file1') finally: remove('src dst') @asynctest def test_put_recurse_existing(self): """Test putting a directory over SCP where target dir exists""" try: os.mkdir('src') os.mkdir('dst') self._create_file('src/file1') yield from scp('src', (self._scp_server, 'dst'), recurse=True) self._check_file('src/file1', 'dst/src/file1') finally: remove('src dst') @asynctest def test_put_must_be_dir(self): """Test putting multiple files to a non-directory over SCP""" try: self._create_file('src1') self._create_file('src2') self._create_file('dst') with self.assertRaises(SFTPError): yield from scp(['src1', 'src2'], (self._scp_server, 'dst')) finally: remove('src1 src2 dst') @asynctest def test_put_non_directory_in_path(self): """Test putting a file with a non-directory in path over SCP""" try: self._create_file('src') with self.assertRaises(OSError): yield from scp('src/xxx', (self._scp_server, 'dst')) finally: remove('src') @asynctest def test_put_recurse_not_directory(self): """Test putting a directory over SCP where target is not directory""" try: os.mkdir('src') self._create_file('dst') self._create_file('src/file1') with self.assertRaises(SFTPError): yield from scp('src', (self._scp_server, 'dst'), recurse=True) finally: remove('src dst') @asynctest def test_put_read_error(self): """Test read errors when putting a file over SCP""" @asyncio.coroutine def _read_error(self, size, offset): """Return an error for reads past 64 KB in a file""" if offset >= 65536: raise OSError(errno.EIO, 'I/O error') else: return (yield from orig_read(self, size, offset)) try: self._create_file('src', 128*1024*'\0') orig_read = LocalFile.read with patch('asyncssh.sftp.LocalFile.read', _read_error): with self.assertRaises(OSError): yield from scp('src', (self._scp_server, 'dst')) finally: remove('src dst') @asynctest def test_put_read_early_eof(self): """Test getting early EOF when putting a file over SCP""" @asyncio.coroutine def _read_early_eof(self, size, offset): """Return an early EOF for reads past 64 KB in a file""" if offset >= 65536: return b'' else: return (yield from orig_read(self, size, offset)) try: self._create_file('src', 128*1024*'\0') orig_read = LocalFile.read with patch('asyncssh.sftp.LocalFile.read', _read_early_eof): with self.assertRaises(SFTPError): yield from scp('src', (self._scp_server, 'dst')) finally: remove('src dst') @asynctest def test_put_name_too_long(self): """Test putting a file over SCP with too long a name""" try: self._create_file('src') with self.assertRaises(SFTPError): yield from scp('src', (self._scp_server, 65536*'a')) finally: remove('src dst') @asynctest def test_copy(self): """Test copying a file between remote hosts over SCP""" try: self._create_file('src') yield from scp((self._scp_server, 'src'), (self._scp_server, 'dst')) self._check_file('src', 'dst') finally: remove('src dst') @asynctest def test_copy_progress(self): """Test copying a file over SCP with progress reporting""" def _report_progress(srcpath, dstpath, bytes_copied, total_bytes): """Monitor progress of copy""" # pylint: disable=unused-argument reports.append(bytes_copied) reports = [] try: self._create_file('src', 100000*'a') yield from scp((self._scp_server, 'src'), (self._scp_server, 'dst'), block_size=8192, progress_handler=_report_progress) self._check_file('src', 'dst') 
self.assertEqual(len(reports), 13) self.assertEqual(reports[-1], 100000) finally: remove('src dst') @asynctest def test_copy_preserve(self): """Test copying a file with preserved attributes between hosts""" try: self._create_file('src', utime=(1, 2)) yield from scp((self._scp_server, 'src'), (self._scp_server, 'dst'), preserve=True) self._check_file('src', 'dst', preserve=True, check_atime=False) finally: remove('src dst') @asynctest def test_copy_recurse(self): """Test recursively copying a directory between hosts over SCP""" try: os.mkdir('src') self._create_file('src/file1') yield from scp((self._scp_server, 'src'), (self._scp_server, 'dst'), recurse=True) self._check_file('src/file1', 'dst/file1') finally: remove('src dst') @asynctest def test_copy_error_handler_source(self): """Test copying multiple files over SCP with error handler""" def err_handler(exc): """Catch error for non-recursive copy of directory""" self.assertEqual(exc.reason, 'scp: Not a regular file: src2') try: self._create_file('src1') os.mkdir('src2') os.mkdir('dst') yield from scp((self._scp_server, 'src*'), (self._scp_server, 'dst'), error_handler=err_handler) self._check_file('src1', 'dst/src1') finally: remove('src1 src2 dst') @asynctest def test_copy_error_handler_sink(self): """Test copying multiple files over SCP with error handler""" def err_handler(exc): """Catch error for non-recursive copy of directory""" if sys.platform == 'win32': # pragma: no cover self.assertEqual(exc.reason, 'scp: Permission denied: dst\\src2') else: self.assertEqual(exc.reason, 'scp: Is a directory: dst/src2') try: self._create_file('src1') self._create_file('src2') os.mkdir('dst') os.mkdir('dst/src2') yield from scp((self._scp_server, 'src*'), (self._scp_server, 'dst'), error_handler=err_handler) self._check_file('src1', 'dst/src1') finally: remove('src1 src2 dst') @asynctest def test_copy_recurse_existing(self): """Test copying a directory over SCP where target dir exists""" try: os.mkdir('src') os.mkdir('dst') self._create_file('src/file1') yield from scp((self._scp_server, 'src'), (self._scp_server, 'dst'), recurse=True) self._check_file('src/file1', 'dst/src/file1') finally: remove('src dst') @asynctest def test_local_copy(self): """Test for error return when attempting to copy local files""" with self.assertRaises(ValueError): yield from scp('src', 'dst') @asynctest def test_copy_multiple(self): """Test copying multiple files over SCP""" try: os.mkdir('src') self._create_file('src/file1') self._create_file('src/file2') yield from scp([(self._scp_server, 'src/file1'), (self._scp_server, 'src/file2')], '.') self._check_file('src/file1', 'file1') self._check_file('src/file2', 'file2') finally: remove('src file1 file2') @asynctest def test_copy_recurse_not_directory(self): """Test copying a directory over SCP where target is not directory""" try: os.mkdir('src') self._create_file('dst') self._create_file('src/file1') with self.assertRaises(SFTPError): yield from scp((self._scp_server, 'src'), (self._scp_server, 'dst'), recurse=True) finally: remove('src dst') @asynctest def test_source_string(self): """Test passing a string to SCP""" with self.assertRaises(OSError): yield from scp('0.0.0.1:xxx', '.') @unittest.skipUnless(python35, 'skip host as bytes before Python 3.5') @asynctest def test_source_bytes(self): """Test passing a byte string to SCP""" with self.assertRaises(OSError): yield from scp(b'0.0.0.1:xxx', '.') @asynctest def test_source_open_connection(self): """Test passing an open SSHClientConnection to SCP as source""" try: 
with (yield from self.connect()) as conn: self._create_file('src') yield from scp((conn, 'src'), 'dst') self._check_file('src', 'dst') finally: remove('src dst') @asynctest def test_destination_open_connection(self): """Test passing an open SSHClientConnection to SCP as destination""" try: with (yield from self.connect()) as conn: os.mkdir('src') self._create_file('src/file1') yield from scp('src/file1', conn) self._check_file('src/file1', 'file1') finally: remove('src file1') @asynctest def test_missing_path(self): """Test running SCP with missing path""" with (yield from self.connect()) as conn: result = yield from conn.run('scp ') self.assertEqual(result.stderr, 'scp: the following arguments ' 'are required: path\n') @asynctest def test_missing_direction(self): """Test running SCP with missing direction argument""" with (yield from self.connect()) as conn: result = yield from conn.run('scp xxx') self.assertEqual(result.stderr, 'scp: one of the arguments -f -t ' 'is required\n') @asynctest def test_invalid_argument(self): """Test running SCP with invalid argument""" with (yield from self.connect()) as conn: result = yield from conn.run('scp -f -x src') self.assertEqual(result.stderr, 'scp: unrecognized arguments: -x\n') @asynctest def test_invalid_c_argument(self): """Test running SCP with invalid argument to C request""" with (yield from self.connect()) as conn: result = yield from conn.run('scp -t dst', input='C\n') self.assertEqual(result.stdout, '\0\x01scp: Invalid copy or dir request\n') @asynctest def test_invalid_t_argument(self): """Test running SCP with invalid argument to C request""" with (yield from self.connect()) as conn: result = yield from conn.run('scp -t -p dst', input='T\n') self.assertEqual(result.stdout, '\0\x01scp: Invalid time request\n') class _TestSCPAsync(_TestSCP): """Unit test for AsyncSSH SCP using an async SFTPServer""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SFTP server with coroutine callbacks""" return (yield from cls.create_server(sftp_factory=_AsyncSFTPServer, allow_scp=True)) class _TestSCPAttrs(_CheckSCP): """Unit test for SCP with SFTP server returning SFTPAttrs""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SFTP server which returns SFTPAttrs from stat""" return (yield from cls.create_server(sftp_factory=_SFTPAttrsSFTPServer, allow_scp=True)) @asynctest def test_get(self): """Test getting a file over SCP with stat returning SFTPAttrs""" try: self._create_file('src') yield from scp((self._scp_server, 'src'), 'dst') self._check_file('src', 'dst') finally: remove('src dst') @asynctest def test_put_recurse_not_directory(self): """Test putting a directory over SCP where target is not directory""" try: os.mkdir('src') self._create_file('dst') self._create_file('src/file1') with self.assertRaises(SFTPError): yield from scp('src', (self._scp_server, 'dst'), recurse=True) finally: remove('src dst') @asynctest def test_put_not_permitted(self): """Test putting a file over SCP onto an unwritable target""" try: self._create_file('src') os.mkdir('dst') os.chmod('dst', 0) with self.assertRaises(SFTPError): yield from scp('src', (self._scp_server, 'dst/src')) finally: os.chmod('dst', 0o755) remove('src dst') @asynctest def test_put_name_too_long(self): """Test putting a file over SCP with too long a name""" try: self._create_file('src') with self.assertRaises(SFTPError): yield from scp('src', (self._scp_server, 65536*'a')) finally: remove('src dst') class _TestSCPIOError(_CheckSCP): """Unit test for SCP with SFTP 
server returning file I/O error""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SFTP server which returns file I/O errors""" return (yield from cls.create_server(sftp_factory=_IOErrorSFTPServer, allow_scp=True)) @asynctest def test_put_error(self): """Test error when putting a file over SCP""" try: self._create_file('src', 4*1024*1024*'\0') with self.assertRaises(SFTPError): yield from scp('src', (self._scp_server, 'dst')) finally: remove('src dst') @asynctest def test_copy_error(self): """Test error when copying a file over SCP""" try: self._create_file('src', 4*1024*1024*'\0') with self.assertRaises(SFTPError): yield from scp((self._scp_server, 'src'), (self._scp_server, 'dst')) finally: remove('src dst') class _TestSCPErrors(_CheckSCP): """Unit test for SCP returning error on startup""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SFTP server which returns file I/O errors""" @asyncio.coroutine def _handle_client(process): """Handle new client""" with process: command = process.command if command.endswith('get_connection_lost'): pass elif command.endswith('get_dir_no_recurse'): yield from process.stdin.read(1) process.stdout.write('D0755 0 src\n') elif command.endswith('get_early_eof'): yield from process.stdin.read(1) process.stdout.write('C0644 10 src\n') yield from process.stdin.read(1) elif command.endswith('get_extra_e'): yield from process.stdin.read(1) process.stdout.write('E\n') yield from process.stdin.read(1) elif command.endswith('get_t_without_preserve'): yield from process.stdin.read(1) process.stdout.write('T0 0 0 0\n') yield from process.stdin.read(1) elif command.endswith('get_unknown_action'): yield from process.stdin.read(1) process.stdout.write('X\n') yield from process.stdin.read(1) elif command.endswith('put_connection_lost'): process.stdout.write('\0\0') elif command.endswith('put_startup_error'): process.stdout.write('Error starting SCP\n') elif command.endswith('recv_early_eof'): process.stdout.write('\0') yield from process.stdin.readline() try: process.stdout.write('\0') except BrokenPipeError: pass else: process.exit(255) return (yield from cls.create_server(process_factory=_handle_client)) @asynctest def test_get_directory_without_recurse(self): """Test receiving directory when recurse wasn't requested""" try: with self.assertRaises(SFTPError): yield from scp((self._scp_server, 'get_dir_no_recurse'), 'dst') finally: remove('dst') @asynctest def test_get_early_eof(self): """Test getting early EOF when getting a file over SCP""" try: with self.assertRaises(SFTPError): yield from scp((self._scp_server, 'get_early_eof'), 'dst') finally: remove('dst') @asynctest def test_get_t_without_preserve(self): """Test getting timestamps with requesting preserve""" try: yield from scp((self._scp_server, 'get_t_without_preserve'), 'dst') finally: remove('dst') @asynctest def test_get_unknown_action(self): """Test getting unknown action from SCP server during get""" try: with self.assertRaises(SFTPError): yield from scp((self._scp_server, 'get_unknown_action'), 'dst') finally: remove('dst') @asynctest def test_put_startup_error(self): """Test SCP server returning an error on startup""" try: self._create_file('src') with self.assertRaises(SFTPError) as exc: yield from scp('src', (self._scp_server, 'put_startup_error')) self.assertEqual(exc.exception.reason, 'Error starting SCP') finally: remove('src') @asynctest def test_put_connection_lost(self): """Test SCP server abruptly closing connection on put""" try: self._create_file('src') 
with self.assertRaises(SFTPError) as exc: yield from scp('src', (self._scp_server, 'put_connection_lost')) self.assertEqual(exc.exception.reason, 'Connection lost') finally: remove('src') @asynctest def test_copy_connection_lost_source(self): """Test source abruptly closing connection during SCP copy""" with self.assertRaises(SFTPError) as exc: yield from scp((self._scp_server, 'get_connection_lost'), (self._scp_server, 'recv_early_eof')) self.assertEqual(exc.exception.reason, 'Connection lost') @asynctest def test_copy_connection_lost_sink(self): """Test sink abruptly closing connection during SCP copy""" with self.assertRaises(SFTPError) as exc: yield from scp((self._scp_server, 'get_early_eof'), (self._scp_server, 'put_connection_lost')) self.assertEqual(exc.exception.reason, 'Connection lost') @asynctest def test_copy_early_eof(self): """Test getting early EOF when copying a file over SCP""" with self.assertRaises(SFTPError): yield from scp((self._scp_server, 'get_early_eof'), (self._scp_server, 'recv_early_eof')) @asynctest def test_copy_extra_e(self): """Test getting extra E when copying a file over SCP""" yield from scp((self._scp_server, 'get_extra_e'), (self._scp_server, 'recv_early_eof')) @asynctest def test_copy_unknown_action(self): """Test getting unknown action from SCP server during copy""" with self.assertRaises(SFTPError): yield from scp((self._scp_server, 'get_unknown_action'), (self._scp_server, 'recv_early_eof')) @asynctest def test_unknown(self): """Test unknown SCP server request for code coverage""" with self.assertRaises(SFTPError): yield from scp('src', (self._scp_server, 'unknown')) asyncssh-1.11.1/tests/test_stream.py000066400000000000000000000231721320320510200174650ustar00rootroot00000000000000# Copyright (c) 2016 by Ron Frederick . # All rights reserved. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for AsyncSSH stream API""" import asyncio import asyncssh from .server import Server, ServerTestCase from .util import asynctest, echo class _StreamServer(Server): """Server for testing the AsyncSSH stream API""" def _begin_session(self, stdin, stdout, stderr): """Begin processing a new session""" # pylint: disable=no-self-use action = stdin.channel.get_command() if not action: action = 'echo' if action == 'echo': yield from echo(stdin, stdout) elif action == 'echo_stderr': yield from echo(stdin, stdout, stderr) elif action == 'close': yield from stdin.read(1) stdout.write('\n') elif action == 'disconnect': stdout.write((yield from stdin.read(1))) raise asyncssh.DisconnectError(asyncssh.DISC_CONNECTION_LOST, 'Connection lost') else: stdin.channel.exit(255) stdin.channel.close() yield from stdin.channel.wait_closed() def _begin_session_non_async(self, stdin, stdout, stderr): """Non-async version of session handler""" self._conn.create_task(self._begin_session(stdin, stdout, stderr)) def begin_auth(self, username): """Handle client authentication request""" return False def session_requested(self): """Handle a request to create a new session""" username = self._conn.get_extra_info('username') if username == 'non_async': return self._begin_session_non_async elif username != 'no_channels': return self._begin_session else: return False class _TestStream(ServerTestCase): """Unit tests for AsyncSSH stream API""" @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server for the tests to use""" return (yield from cls.create_server(_StreamServer)) @asyncio.coroutine def _check_session(self, conn, large_block=False): """Open a session and test if an input line is echoed back""" stdin, stdout, stderr = yield from conn.open_session('echo_stderr') if large_block: data = 4 * [1025*1024*'\0'] else: data = [str(id(self))] stdin.writelines(data) yield from stdin.drain() self.assertTrue(stdin.can_write_eof()) stdin.write_eof() stdout_data, stderr_data = yield from asyncio.gather(stdout.read(), stderr.read()) data = ''.join(data) self.assertEqual(data, stdout_data) self.assertEqual(data, stderr_data) yield from stdin.channel.wait_closed() yield from stdin.drain() stdin.close() @asynctest def test_shell(self): """Test starting a shell""" with (yield from self.connect()) as conn: yield from self._check_session(conn) yield from conn.wait_closed() @asynctest def test_shell_failure(self): """Test failure to start a shell""" with (yield from self.connect(username='no_channels')) as conn: with self.assertRaises(asyncssh.ChannelOpenError): yield from conn.open_session() yield from conn.wait_closed() @asynctest def test_shell_non_async(self): """Test starting a shell using non-async handler""" with (yield from self.connect(username='non_async')) as conn: yield from self._check_session(conn) yield from conn.wait_closed() @asynctest def test_large_block(self): """Test sending and receiving a large block of data""" with (yield from self.connect()) as conn: yield from self._check_session(conn, large_block=True) yield from conn.wait_closed() @asynctest def test_write_broken_pipe(self): """Test close while we're writing""" with (yield from self.connect()) as conn: stdin, _, _ = yield from 
conn.open_session('close') stdin.write(4*1024*1024*'\0') with self.assertRaises((ConnectionError, asyncssh.DisconnectError)): yield from stdin.drain() yield from conn.wait_closed() @asynctest def test_write_disconnect(self): """Test disconnect while we're writing""" with (yield from self.connect()) as conn: stdin, _, _ = yield from conn.open_session('disconnect') stdin.write(6*1024*1024*'\0') with self.assertRaises((ConnectionError, asyncssh.DisconnectError)): yield from stdin.drain() yield from conn.wait_closed() @asynctest def test_multiple_read(self): """Test calling blocking read multiple times""" with (yield from self.connect()) as conn: stdin, stdout, _ = yield from conn.open_session() done, _ = yield from asyncio.wait( [stdout.read(), stdout.read()], return_when=asyncio.FIRST_EXCEPTION) with self.assertRaises(RuntimeError): yield from done stdin.close() yield from conn.wait_closed() @asynctest def test_read_exception(self): """Test read returning an exception""" with (yield from self.connect()) as conn: stdin, stdout, _ = yield from conn.open_session('disconnect') stdin.write('\0') self.assertEqual((yield from stdout.read()), '\0') with self.assertRaises(asyncssh.DisconnectError): yield from stdout.read(1) stdin.close() yield from conn.wait_closed() @asynctest def test_readline_exception(self): """Test readline returning an exception""" with (yield from self.connect()) as conn: stdin, stdout, _ = yield from conn.open_session('disconnect') stdin.write('\0') self.assertEqual((yield from stdout.readline()), '\0') with self.assertRaises(asyncssh.DisconnectError): yield from stdout.readline() stdin.close() yield from conn.wait_closed() @asynctest def test_pause_read(self): """Test pause reading""" with (yield from self.connect()) as conn: stdin, stdout, _ = yield from conn.open_session() stdin.write(6*1024*1024*'\0') yield from asyncio.sleep(0.01) yield from stdout.read(1) yield from asyncio.sleep(0.01) yield from stdout.read(1) stdin.channel.abort() yield from conn.wait_closed() @asynctest def test_pause_readline(self): """Test pause reading while calling readline""" with (yield from self.connect()) as conn: stdin, stdout, _ = yield from conn.open_session() stdin.write('\n'+2*1024*1024*'\0') stdin.write_eof() yield from asyncio.sleep(0.01) yield from stdout.readline() yield from asyncio.sleep(0.01) yield from stdout.readline() stdin.channel.abort() yield from conn.wait_closed() @asynctest def test_readuntil(self): """Test readuntil with multi-character separator""" with (yield from self.connect()) as conn: stdin, stdout, _ = yield from conn.open_session() stdin.write('abc\r') yield from asyncio.sleep(0.01) stdin.write('\ndef') yield from asyncio.sleep(0.01) stdin.write('\r\n') yield from asyncio.sleep(0.01) stdin.write('ghi') stdin.write_eof() self.assertEqual((yield from stdout.readuntil('\r\n')), 'abc\r\n') self.assertEqual((yield from stdout.readuntil('\r\n')), 'def\r\n') with self.assertRaises(asyncio.IncompleteReadError) as exc: yield from stdout.readuntil('\r\n') self.assertEqual(exc.exception.partial, 'ghi') stdin.close() yield from conn.wait_closed() @asynctest def test_readuntil_empty_separator(self): """Test readuntil with empty separator""" with (yield from self.connect()) as conn: stdin, stdout, _ = yield from conn.open_session() with self.assertRaises(ValueError): yield from stdout.readuntil('') stdin.close() yield from conn.wait_closed() @asynctest def test_get_extra_info(self): """Test get_extra_info on streams""" with (yield from self.connect()) as conn: stdin, stdout, _ 
                = yield from conn.open_session()

            self.assertEqual(stdin.get_extra_info('connection'),
                             stdout.get_extra_info('connection'))

            stdin.close()

        yield from conn.wait_closed()

    @asynctest
    def test_unknown_action(self):
        """Test unknown action"""

        with (yield from self.connect()) as conn:
            stdin, _, _ = yield from conn.open_session('unknown')

            yield from stdin.channel.wait_closed()
            self.assertEqual(stdin.channel.get_exit_status(), 255)

        yield from conn.wait_closed()

asyncssh-1.11.1/tests/test_x11.py000066400000000000000000000505261320320510200166060ustar00rootroot00000000000000
# Copyright (c) 2016-2017 by Ron Frederick .
# All rights reserved.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v1.0 which accompanies this
# distribution and is available at:
#
#     http://www.eclipse.org/legal/epl-v10.html
#
# Contributors:
#     Ron Frederick - initial implementation, API, and documentation

"""Unit tests for AsyncSSH X11 forwarding"""

import asyncio
import os
import socket

from unittest.mock import patch

import asyncssh

from asyncssh.packet import Boolean, String, UInt32
from asyncssh.x11 import XAUTH_FAMILY_IPV4, XAUTH_FAMILY_DECNET
from asyncssh.x11 import XAUTH_FAMILY_IPV6, XAUTH_FAMILY_HOSTNAME
from asyncssh.x11 import XAUTH_FAMILY_WILD, XAUTH_PROTO_COOKIE
from asyncssh.x11 import XAUTH_COOKIE_LEN, X11_BASE_PORT, X11_LISTEN_HOST
from asyncssh.x11 import SSHXAuthorityEntry, SSHX11ClientListener
from asyncssh.x11 import walk_xauth, lookup_xauth, update_xauth

from .server import Server, ServerTestCase
from .util import asynctest


def _failing_bind(self, address):
    """Raise OSError to simulate a socket bind failure"""

    # pylint: disable=unused-argument

    raise OSError


@asyncio.coroutine
def _create_x11_process(conn, command=None, x11_display='test:0', **kwargs):
    """Create a client process with X11 forwarding enabled"""

    return (yield from conn.create_process(command, x11_forwarding=True,
                                           x11_display=x11_display, **kwargs))


class _X11Peer:
    """Peer representing the X server to forward connections to"""

    expected_auth = b''

    @classmethod
    @asyncio.coroutine
    def connect(cls, session_factory, host, port):
        """Simulate connecting to an X server"""

        # pylint: disable=unused-argument

        if port == X11_BASE_PORT:
            return None, cls()
        else:
            raise OSError('Connection refused')

    def __init__(self):
        self._peer = None
        self._check_auth = True

    def set_peer(self, peer):
        """Set the peer forwarder to exchange data with"""

        self._peer = peer

    def write(self, data):
        """Consume data from the peer"""

        if self._check_auth:
            match = data[32:48] == self.expected_auth
            self._peer.write(b'\x01' if match else b'\xff')
            self._check_auth = False
        else:
            self._peer.write(data)

    def write_eof(self):
        """Consume EOF from the peer"""

        pass  # pragma: no cover

    def was_eof_received(self):
        """Report that an incoming EOF has not been received"""

        # pylint: disable=no-self-use

        return False  # pragma: no cover

    def pause_reading(self):
        """Ignore flow control requests"""

        # pylint: disable=no-self-use

        pass  # pragma: no cover

    def resume_reading(self):
        """Ignore flow control requests"""

        # pylint: disable=no-self-use

        pass  # pragma: no cover

    def close(self):
        """Consume close request"""

        pass  # pragma: no cover


class _X11ClientListener(SSHX11ClientListener):
    """Unit test X11 client listener which forwards to the stub X server"""

    @asyncio.coroutine
    def forward_connection(self):
        """Forward a connection to this server"""

        self._connect_coro = _X11Peer.connect

        return super().forward_connection()


class _X11ClientChannel(asyncssh.SSHClientChannel):
    """Patched
X11 client channel for unit testing""" @asyncio.coroutine def make_x11_forwarding_request(self, proto, data, screen): """Make a request to enable X11 forwarding""" return (yield from self._make_request(b'x11-req', Boolean(False), String(proto), String(data), UInt32(screen))) class _X11ServerConnection(asyncssh.SSHServerConnection): """Unit test X11 forwarding server connection""" @asyncio.coroutine def attach_x11_listener(self, chan, auth_proto, auth_data, screen): """Attach a channel to a remote X11 display""" if screen == 9: return False _X11Server.auth_proto = auth_proto _X11Server.auth_data = auth_data return (yield from super().attach_x11_listener(chan, auth_proto, auth_data, screen)) class _X11Server(Server): """Server for testing AsyncSSH X11 forwarding""" auth_proto = b'' auth_data = b'' @staticmethod def _uint16(value, endian): """Encode a 16-bit value using the specified endianness""" if endian == 'B': return bytes((value >> 8, value & 255)) else: return bytes((value & 255, value >> 8)) @staticmethod def _pad(data): """Pad a string to a multiple of 4 bytes""" length = len(data) % 4 return data + ((4 - length) * b'\00' if length else b'') def _open_x11(self, chan, endian, bad): """Open an X11 connection back to the client""" display = chan.get_x11_display() if display: dpynum = int(display.rsplit(':')[-1].split('.')[0]) else: return 2 reader, writer = yield from asyncio.open_connection( X11_LISTEN_HOST, X11_BASE_PORT + dpynum) auth_data = bytearray(self.auth_data) if bad: auth_data[-1] ^= 0xff request = b''.join((endian.encode('ascii'), b'\x00', self._uint16(11, endian), self._uint16(0, endian), self._uint16(len(self.auth_proto), endian), self._uint16(len(self.auth_data), endian), b'\x00\x00', self._pad(self.auth_proto), self._pad(auth_data))) writer.write(request[:24]) yield from asyncio.sleep(0.1) writer.write(request[24:]) result = yield from reader.read(1) if result == b'': result = b'\x02' if result == b'\x01': writer.write(b'\x00') writer.close() return result[0] @asyncio.coroutine def _begin_session(self, stdin, stdout, stderr): """Begin processing a new session""" # pylint: disable=unused-argument action = stdin.channel.get_command() if action: if action.startswith('connect '): endian = action[8:9] bad = bool(action[9:] == 'X') result = yield from self._open_x11(stdin.channel, endian, bad) stdin.channel.exit(result) elif action == 'attach': with patch('socket.socket.bind', _failing_bind): result = yield from self._conn.attach_x11_listener( None, b'', b'', 0) stdin.channel.exit(bool(result)) elif action == 'open': try: result = yield from self._conn.create_x11_connection(None) except asyncssh.ChannelOpenError: result = None stdin.channel.exit(bool(result)) elif action == 'sleep': yield from asyncio.sleep(0.1) else: stdin.channel.exit(255) stdin.channel.close() yield from stdin.channel.wait_closed() def session_requested(self): return self._begin_session @patch('asyncssh.connection.SSHServerConnection', _X11ServerConnection) @patch('asyncssh.x11.SSHX11ClientListener', _X11ClientListener) class _TestX11(ServerTestCase): """Unit tests for AsyncSSH X11 forwarding""" @classmethod def setUpClass(cls): """Create Xauthority file needed for test""" super().setUpClass() auth_data = os.urandom(XAUTH_COOKIE_LEN) with open('.Xauthority', 'wb') as auth_file: auth_file.write(bytes(SSHXAuthorityEntry( XAUTH_FAMILY_HOSTNAME, b'test', b'1', XAUTH_PROTO_COOKIE, auth_data))) auth_file.write(bytes(SSHXAuthorityEntry( XAUTH_FAMILY_HOSTNAME, b'test', b'0', XAUTH_PROTO_COOKIE, auth_data))) 
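            # The records written above assume that bytes(SSHXAuthorityEntry(...))
            # produces the standard ~/.Xauthority record layout: a 16-bit
            # big-endian family code followed by four length-prefixed fields
            # (address, display number, protocol name, auth data). The helper
            # below is only a hedged, illustrative sketch of that assumed
            # layout; it is never called by these tests.
            def _raw_xauth_record(family, address, number, proto, data):
                """Pack one Xauthority record using the assumed wire layout"""

                def _field(value):
                    # Each field is a 2-byte big-endian length plus the bytes
                    return len(value).to_bytes(2, 'big') + value

                return (family.to_bytes(2, 'big') +
                        b''.join(_field(value)
                                 for value in (address, number, proto, data)))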
auth_file.write(bytes(SSHXAuthorityEntry( XAUTH_FAMILY_IPV4, socket.inet_pton(socket.AF_INET, '0.0.0.2'), b'0', XAUTH_PROTO_COOKIE, auth_data))) auth_file.write(bytes(SSHXAuthorityEntry( XAUTH_FAMILY_IPV4, socket.inet_pton(socket.AF_INET, '0.0.0.1'), b'0', XAUTH_PROTO_COOKIE, auth_data))) auth_file.write(bytes(SSHXAuthorityEntry( XAUTH_FAMILY_IPV6, socket.inet_pton(socket.AF_INET6, '::2'), b'0', XAUTH_PROTO_COOKIE, auth_data))) auth_file.write(bytes(SSHXAuthorityEntry( XAUTH_FAMILY_IPV6, socket.inet_pton(socket.AF_INET6, '::1'), b'0', XAUTH_PROTO_COOKIE, auth_data))) # Added to cover case where we don't match on address family auth_file.write(bytes(SSHXAuthorityEntry( XAUTH_FAMILY_DECNET, b'test', b'0', XAUTH_PROTO_COOKIE, auth_data))) # Wildcard address family match auth_file.write(bytes(SSHXAuthorityEntry( XAUTH_FAMILY_WILD, b'', b'0', XAUTH_PROTO_COOKIE, auth_data))) with open('.Xauthority-empty', 'wb'): pass with open('.Xauthority-corrupted', 'wb') as auth_file: auth_file.write(b'\x00\x00\x00') _X11Peer.expected_auth = auth_data @classmethod @asyncio.coroutine def start_server(cls): """Start an SSH server for the tests to use""" return (yield from cls.create_server( _X11Server, x11_forwarding=True, authorized_client_keys='authorized_keys')) @asyncio.coroutine def _check_x11(self, command=None, *, exc=None, exit_status=None, **kwargs): """Check requesting X11 forwarding""" with (yield from self.connect()) as conn: if exc: with self.assertRaises(exc): yield from _create_x11_process(conn, command, **kwargs) else: proc = yield from _create_x11_process(conn, command, **kwargs) yield from proc.wait() self.assertEqual(proc.exit_status, exit_status) yield from conn.wait_closed() @asynctest def test_xauth_lookup(self): """Test writing an xauth entry and looking it back up""" yield from update_xauth(asyncio.get_event_loop(), 'xauth', 'test', '0', b'', b'\x00') _, auth_data = yield from lookup_xauth(asyncio.get_event_loop(), 'xauth', 'test', '0') os.unlink('xauth') self.assertEqual(auth_data, b'\x00') @asynctest def test_xauth_dead_lock(self): """Test removal of dead Xauthority lock""" with open('xauth-c', 'w'): pass yield from asyncio.sleep(6) yield from update_xauth(asyncio.get_event_loop(), 'xauth', 'test', '0', b'', b'\x00') _, auth_data = yield from lookup_xauth(asyncio.get_event_loop(), 'xauth', 'test', '0') os.unlink('xauth') self.assertEqual(auth_data, b'\x00') @asynctest def test_xauth_update(self): """Test overwriting an xauth entry""" yield from update_xauth(asyncio.get_event_loop(), 'xauth', 'test', '0', b'', b'\x00') yield from update_xauth(asyncio.get_event_loop(), 'xauth', 'test', '0', b'', b'\x01') self.assertEqual(len(list(walk_xauth('xauth'))), 1) _, auth_data = yield from lookup_xauth(asyncio.get_event_loop(), 'xauth', 'test', '0') os.unlink('xauth') self.assertEqual(auth_data, b'\x01') @asynctest def test_forward_big(self): """Test requesting X11 forwarding with big-endian connect""" yield from self._check_x11('connect B', exit_status=1, x11_display='test:0.0', x11_single_connection=True) @asynctest def test_forward_little(self): """Test requesting X11 forwarding with little-endian connect""" yield from self._check_x11('connect l', exit_status=1) @asynctest def test_connection_refused_big(self): """Test the X server refusing connection with big-endian connect""" yield from self._check_x11('connect B', exit_status=2, x11_display='test:1') @asynctest def test_connection_refused_little(self): """Test the X server refusing connection with little-endian connect""" yield from 
self._check_x11('connect l', exit_status=2, x11_display='test:1') @asynctest def test_bad_auth_big(self): """Test sending bad auth data with big-endian connect""" yield from self._check_x11('connect BX', exit_status=0) @asynctest def test_bad_auth_little(self): """Test sending bad auth data with little-endian connect""" yield from self._check_x11('connect lX', exit_status=0) @asynctest def test_ipv4_address(self): """Test matching against an IPv4 address""" yield from self._check_x11(x11_display='0.0.0.1:0') @asynctest def test_ipv6_address(self): """Test matching against an IPv6 address""" yield from self._check_x11(x11_display='[::1]:0') @asynctest def test_wildcard_address(self): """Test matching against a wildcard host entry""" yield from self._check_x11(x11_display='wild:0') @asynctest def test_local_server(self): """Test matching against a local X server""" yield from self._check_x11(x11_display=':0') @asynctest def test_domain_socket(self): """Test matching against an explicit domain socket""" yield from self._check_x11(x11_display='/test:0') @asynctest def test_display_environment(self): """Test getting X11 display from the environment""" os.environ['DISPLAY'] = 'test:0' yield from self._check_x11(x11_display=None) del os.environ['DISPLAY'] @asynctest def test_display_not_set(self): """Test requesting X11 forwarding with no display set""" yield from self._check_x11(exc=asyncssh.ChannelOpenError, x11_display=None) @asynctest def test_forwarding_denied(self): """Test SSH server denying X11 forwarding""" yield from self._check_x11(exc=asyncssh.ChannelOpenError, x11_display='test:0.9') @asynctest def test_xauth_environment(self): """Test getting Xauthority path from the environment""" os.environ['XAUTHORITY'] = '.Xauthority' yield from self._check_x11() del os.environ['XAUTHORITY'] @asynctest def test_no_xauth_match(self): """Test no xauth match""" yield from self._check_x11(x11_display='no_match:1') @asynctest def test_invalid_display(self): """Test invalid X11 display value""" yield from self._check_x11(exc=asyncssh.ChannelOpenError, x11_display='test') @asynctest def test_xauth_missing(self): """Test missing .Xauthority file""" yield from self._check_x11(x11_auth_path='.Xauthority-missing') @asynctest def test_xauth_empty(self): """Test empty .Xauthority file""" yield from self._check_x11(x11_auth_path='.Xauthority-empty') @asynctest def test_xauth_corrupted(self): """Test .Xauthority file with corrupted entry""" yield from self._check_x11(exc=asyncssh.ChannelOpenError, x11_auth_path='.Xauthority-corrupted') @asynctest def test_selective_forwarding(self): """Test requesting X11 forwarding from one session and not another""" with (yield from self.connect()) as conn: yield from conn.create_process('sleep') yield from _create_x11_process(conn, 'sleep', x11_display='test:0') yield from conn.wait_closed() @asynctest def test_multiple_sessions(self): """Test requesting X11 forwarding from two different sessions""" with (yield from self.connect()) as conn: yield from _create_x11_process(conn) yield from _create_x11_process(conn) yield from conn.wait_closed() @asynctest def test_simultaneous_sessions(self): """Test X11 forwarding from multiple sessions simultaneously""" with (yield from self.connect()) as conn: yield from _create_x11_process(conn, 'sleep') yield from _create_x11_process(conn, 'sleep', x11_display='test:0.1') yield from conn.wait_closed() @asynctest def test_consecutive_different_servers(self): """Test X11 forwarding to different X servers consecutively""" with (yield from 
self.connect()) as conn: proc = yield from _create_x11_process(conn) yield from proc.wait() yield from _create_x11_process(conn, x11_display='test1:0') yield from conn.wait_closed() @asynctest def test_simultaneous_different_servers(self): """Test X11 forwarding to different X servers simultaneously""" with (yield from self.connect()) as conn: yield from _create_x11_process(conn, 'sleep') with self.assertRaises(asyncssh.ChannelOpenError): yield from _create_x11_process(conn, x11_display='test1:0') yield from conn.wait_closed() @asynctest def test_forwarding_disabled(self): """Test X11 request when forwarding was never enabled""" with (yield from self.connect()) as conn: result = yield from conn.run('connect l') self.assertEqual(result.exit_status, 2) yield from conn.wait_closed() @asynctest def test_attach_failure(self): """Test X11 listener attach when forwarding was never enabled""" with (yield from self.connect()) as conn: result = yield from conn.run('attach') self.assertEqual(result.exit_status, 0) yield from conn.wait_closed() @asynctest def test_attach_lock_failure(self): """Test X11 listener attach when Xauthority can't be locked""" with open('.Xauthority-c', 'w'): pass yield from self._check_x11('connect l', exc=asyncssh.ChannelOpenError) os.unlink('.Xauthority-c') @asynctest def test_open_failure(self): """Test opening X11 connection when forwarding was never enabled""" with (yield from self.connect()) as conn: result = yield from conn.run('open') self.assertEqual(result.exit_status, 0) yield from conn.wait_closed() @asynctest def test_forwarding_not_allowed(self): """Test an X11 request from a non-authorized user""" ckey = asyncssh.read_private_key('ckey') cert = ckey.generate_user_certificate(ckey, 'name', principals=['ckey'], permit_x11_forwarding=False) with (yield from self.connect(username='ckey', client_keys=[(ckey, cert)])) as conn: with self.assertRaises(asyncssh.ChannelOpenError): yield from _create_x11_process(conn, 'connect l') yield from conn.wait_closed() @asynctest def test_invalid_x11_forwarding_request(self): """Test an invalid X11 forwarding request""" with patch('asyncssh.connection.SSHClientChannel', _X11ClientChannel): with (yield from self.connect()) as conn: stdin, _, _ = yield from conn.open_session('sleep') result = yield from stdin.channel.make_x11_forwarding_request( '', 'xx', 0) yield from conn.wait_closed() self.assertFalse(result) @asynctest def test_unknown_action(self): """Test unknown action""" with (yield from self.connect()) as conn: result = yield from conn.run('unknown') self.assertEqual(result.exit_status, 255) yield from conn.wait_closed() asyncssh-1.11.1/tests/test_x509.py000066400000000000000000000241411320320510200166740ustar00rootroot00000000000000# Copyright (c) 2017 by Ron Frederick . # All rights reserved. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for X.509 certificate handling""" import time import unittest from cryptography import x509 import asyncssh from .util import x509_available if x509_available: # pragma: no branch from asyncssh.crypto import X509Name, X509NamePattern from asyncssh.crypto import generate_x509_certificate from asyncssh.crypto import import_x509_certificate _purpose_secureShellClient = x509.ObjectIdentifier('1.3.6.1.5.5.7.3.21') @unittest.skipUnless(x509_available, 'X.509 not available') class _TestX509(unittest.TestCase): """Unit tests for X.509 module""" @classmethod def setUpClass(cls): cls._privkey = asyncssh.generate_private_key('ssh-rsa') cls._pubkey = cls._privkey.convert_to_public() cls._pubdata = cls._pubkey.export_public_key('pkcs8-der') def generate_certificate(self, subject='OU=name', issuer=None, serial=None, valid_after=0, valid_before=0xffffffffffffffff, ca=False, ca_path_len=None, purposes=None, user_principals=(), host_principals=(), hash_alg='sha256', comment=None): """Generate and check an X.509 certificate""" cert = generate_x509_certificate(self._privkey, self._pubkey, subject, issuer, serial, valid_after, valid_before, ca, ca_path_len, purposes, user_principals, host_principals, hash_alg, comment) self.assertEqual(cert.data, import_x509_certificate(cert.data).data) self.assertEqual(cert.subject, X509Name(subject)) self.assertEqual(cert.issuer, X509Name(issuer if issuer else subject)) self.assertEqual(cert.key_data, self._pubdata) self.assertEqual(cert.comment, comment) return cert def test_generate(self): """Test X.509 certificate generation""" cert = self.generate_certificate(purposes='secureShellClient') self.assertEqual(cert.purposes, set((_purpose_secureShellClient,))) def test_generate_ca(self): """Test X.509 CA certificate generation""" self.generate_certificate(ca=True, ca_path_len=0) def test_serial(self): """Test X.509 certificate generation with serial number""" self.generate_certificate(serial=1) def test_user_principals(self): """Test X.509 certificate generation with user principals""" cert = self.generate_certificate(user_principals='user1,user2') self.assertEqual(cert.user_principals, ['user1', 'user2']) def test_host_principals(self): """Test X.509 certificate generation with host principals""" cert = self.generate_certificate(host_principals='host1,host2') self.assertEqual(cert.host_principals, ['host1', 'host2']) def test_principal_in_common_name(self): """Test X.509 certificate generation with user principals""" cert = self.generate_certificate(subject='CN=name') self.assertEqual(cert.user_principals, ['name']) self.assertEqual(cert.host_principals, ['name']) def test_unknown_hash(self): """Test X.509 certificate generation with unknown hash""" with self.assertRaises(ValueError): self.generate_certificate(hash_alg='xxx') def test_invalid_comment(self): """Test X.509 certificate generation with invalid comment""" with self.assertRaises(ValueError): self.generate_certificate(comment=b'\xff') def test_valid_self(self): """Test validation of X.509 self-signed certificate""" cert = self.generate_certificate() self.assertIsNone(cert.validate([cert], None, None, None)) def test_untrusted_self(self): """Test failed validation of untrusted X.509 self-signed 
certificate""" cert1 = self.generate_certificate() cert2 = self.generate_certificate() with self.assertRaises(ValueError): cert1.validate([cert2], None, None, None) def test_valid_chain(self): """Test validation of X.509 certificate chain""" root_ca = self.generate_certificate('OU=root', ca=True, ca_path_len=1) int_ca = self.generate_certificate('OU=int', 'OU=root', ca=True, ca_path_len=0) cert = self.generate_certificate('OU=user', 'OU=int') self.assertIsNone(cert.validate([int_ca, root_ca], None, None, None)) def test_incomplete_chain(self): """Test failed validation of incomplete X.509 certificate chain""" root_ca = self.generate_certificate('OU=root', ca=True, ca_path_len=1) int_ca = self.generate_certificate('OU=int', 'OU=root', ca=True, ca_path_len=0) cert = self.generate_certificate('OU=user', 'OU=int2') with self.assertRaises(ValueError): cert.validate([int_ca, root_ca], None, None, None) def test_not_yet_valid_self(self): """Test failed validation of not-yet-valid X.509 certificate""" cert = self.generate_certificate(valid_after=time.time() + 60) with self.assertRaises(ValueError): cert.validate([cert], None, None, None) def test_expired_self(self): """Test failed validation of expired X.509 certificate""" cert = self.generate_certificate(valid_before=time.time() - 60) with self.assertRaises(ValueError): cert.validate([cert], None, None, None) def test_expired_intermediate(self): """Test failed validation of expired X.509 intermediate CA""" root_ca = self.generate_certificate('OU=root', ca=True, ca_path_len=1) int_ca = self.generate_certificate('OU=int', 'OU=root', ca=True, ca_path_len=0, valid_before=time.time() - 60) cert = self.generate_certificate('OU=user', 'OU=int') with self.assertRaises(ValueError): cert.validate([int_ca, root_ca], None, None, None) def test_expired_root(self): """Test failed validation of expired X.509 root CA""" root_ca = self.generate_certificate('OU=root', ca=True, ca_path_len=1, valid_before=time.time() - 60) int_ca = self.generate_certificate('OU=int', 'OU=root', ca=True, ca_path_len=0) cert = self.generate_certificate('OU=user', 'OU=int') with self.assertRaises(ValueError): cert.validate([int_ca, root_ca], None, None, None) def test_purpose_mismatch(self): """Test failed validation due to purpose mismatch""" cert = self.generate_certificate(purposes='secureShellClient') with self.assertRaises(ValueError): cert.validate([cert], 'secureShellServer', None, None) def test_user_principal_match(self): """Test validation of user principal""" cert = self.generate_certificate(user_principals='user') self.assertIsNone(cert.validate([cert], None, 'user', None)) def test_user_principal_mismatch(self): """Test failed validation due to user principal mismatch""" cert = self.generate_certificate(user_principals='user1,user2') with self.assertRaises(ValueError): cert.validate([cert], None, 'user3', None) def test_host_principal_match(self): """Test validation of host principal""" cert = self.generate_certificate(host_principals='host') self.assertIsNone(cert.validate([cert], None, None, 'host')) def test_host_principal_mismatch(self): """Test failed validation due to host principal mismatch""" cert = self.generate_certificate(host_principals='host1,host2') with self.assertRaises(ValueError): cert.validate([cert], None, None, 'host3') def test_name(self): """Test X.509 distinguished name generation""" name = X509Name('O=Org,OU=Unit') self.assertEqual(name, X509Name('O=Org, OU=Unit')) self.assertEqual(name, X509Name(name)) self.assertEqual(name, X509Name(name.rdns)) 
self.assertEqual(len(name), 2) self.assertEqual(len(name.rdns), 2) self.assertEqual(str(name), 'O=Org,OU=Unit') self.assertNotEqual(name, X509Name('OU=Unit,O=Org')) def test_multiple_attrs_in_rdn(self): """Test multiple attributes in a relative distinguished name""" name1 = X509Name('O=Org,OU=Unit1+OU=Unit2') name2 = X509Name('O=Org,OU=Unit2+OU=Unit1') self.assertEqual(name1, name2) self.assertEqual(len(name1), 3) self.assertEqual(len(name1.rdns), 2) def test_invalid_attribute(self): """Test X.509 distinguished name with invalid attributes""" with self.assertRaises(ValueError): X509Name('xxx') with self.assertRaises(ValueError): X509Name('X=xxx') def test_exact_name_pattern(self): """Test X.509 distinguished name exact match""" pattern1 = X509NamePattern('O=Org,OU=Unit') pattern2 = X509NamePattern('O=Org, OU=Unit') self.assertEqual(pattern1, pattern2) self.assertEqual(hash(pattern1), hash(pattern2)) self.assertTrue(pattern1.matches(X509Name('O=Org,OU=Unit'))) self.assertFalse(pattern1.matches(X509Name('O=Org,OU=Unit2'))) def test_prefix_pattern(self): """Test X.509 distinguished name prefix match""" pattern = X509NamePattern('O=Org,*') self.assertTrue(pattern.matches(X509Name('O=Org,OU=Unit'))) self.assertFalse(pattern.matches(X509Name('O=Org2,OU=Unit'))) asyncssh-1.11.1/tests/util.py000066400000000000000000000224221320320510200161050ustar00rootroot00000000000000# Copyright (c) 2015-2017 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Utility functions for unit tests""" import asyncio import binascii import functools import os import subprocess import sys import tempfile import unittest from unittest.mock import patch # pylint: disable=unused-import try: import bcrypt bcrypt_available = hasattr(bcrypt, 'kdf') except ImportError: # pragma: no cover bcrypt_available = False try: import libnacl libnacl_available = True except (ImportError, OSError, AttributeError): # pragma: no cover libnacl_available = False try: from asyncssh.crypto import X509Name x509_available = True except ImportError: # pragma: no cover x509_available = False # pylint: enable=unused-import from asyncssh.constants import DISC_CONNECTION_LOST from asyncssh.gss import gss_available from asyncssh.misc import DisconnectError, SignalReceived, create_task from asyncssh.packet import String, UInt32, UInt64 def asynctest(func): """Decorator for async tests, for use with AsyncTestCase""" @functools.wraps(func) def async_wrapper(self, *args, **kwargs): """Run a function as a coroutine and wait for it to finish""" wrapped_func = asyncio.coroutine(func)(self, *args, **kwargs) return self.loop.run_until_complete(wrapped_func) return async_wrapper def asynctest35(func): """Decorator for Python 3.5 async tests, for use with AsyncTestCase""" @functools.wraps(func) def async_wrapper(self, *args, **kwargs): """Run a function as a coroutine and wait for it to finish""" wrapped_func = func(self, *args, **kwargs) return self.loop.run_until_complete(wrapped_func) return async_wrapper def patch_gss(cls): """Decorator for patching GSSAPI classes""" if not gss_available: # pragma: no cover return cls if sys.platform == 'win32': # pragma: no cover from .sspi_stub import SSPIAuth cls = patch('asyncssh.gss_win32.ClientAuth', SSPIAuth)(cls) cls = 
patch('asyncssh.gss_win32.ServerAuth', SSPIAuth)(cls) else: from .gssapi_stub import Name, Credentials, RequirementFlag from .gssapi_stub import SecurityContext cls = patch('asyncssh.gss_unix.Name', Name)(cls) cls = patch('asyncssh.gss_unix.Credentials', Credentials)(cls) cls = patch('asyncssh.gss_unix.RequirementFlag', RequirementFlag)(cls) cls = patch('asyncssh.gss_unix.SecurityContext', SecurityContext)(cls) return cls @asyncio.coroutine def echo(stdin, stdout, stderr=None): """Echo data from stdin back to stdout and stderr (if open)""" try: while not stdin.at_eof(): data = yield from stdin.read(65536) if data: stdout.write(data) if stderr: stderr.write(data) yield from stdout.drain() if stderr: yield from stderr.drain() stdout.write_eof() except SignalReceived as exc: if exc.signal == 'ABRT': raise DisconnectError(DISC_CONNECTION_LOST, 'Abort') else: stdin.channel.exit_with_signal(exc.signal) except OSError: pass stdout.close() def _encode_options(options): """Encode SSH certificate critical options and extensions""" return b''.join((String(k) + String(v) for k, v in options.items())) def make_certificate(cert_version, cert_type, key, signing_key, principals, key_id='name', valid_after=0, valid_before=0xffffffffffffffff, options=None, extensions=None, bad_signature=False): """Construct an SSH certificate""" keydata = key.encode_ssh_public() principals = b''.join((String(p) for p in principals)) options = _encode_options(options) if options else b'' extensions = _encode_options(extensions) if extensions else b'' signing_keydata = b''.join((String(signing_key.algorithm), signing_key.encode_ssh_public())) data = b''.join((String(cert_version), String(os.urandom(32)), keydata, UInt64(0), UInt32(cert_type), String(key_id), String(principals), UInt64(valid_after), UInt64(valid_before), String(options), String(extensions), String(''), String(signing_keydata))) if bad_signature: data += String('') else: data += String(signing_key.sign(data, signing_key.algorithm)) return b''.join((cert_version.encode('ascii'), b' ', binascii.b2a_base64(data))) def run(cmd): """Run a shell commands and return the output""" try: return subprocess.check_output(cmd, shell=True, stderr=subprocess.STDOUT) except subprocess.CalledProcessError as exc: # pragma: no cover print(exc.output.decode()) raise class ConnectionStub: """Stub class used to replace an SSHConnection object""" def __init__(self, peer, server): self._peer = peer self._server = server if peer: self._packet_queue = asyncio.queues.Queue() self._queue_task = self.create_task(self._process_packets()) else: self._packet_queue = None self._queue_task = None @asyncio.coroutine def _run_task(self, coro): """Run an asynchronous task""" # pylint: disable=broad-except try: yield from coro except Exception as exc: if self._peer: # pragma: no branch self.queue_packet(exc) self.connection_lost(exc) def create_task(self, coro): """Create an asynchronous task""" return create_task(self._run_task(coro)) def is_client(self): """Return if this is a client connection""" return not self._server def is_server(self): """Return if this is a server connection""" return self._server def get_peer(self): """Return the peer of this connection""" return self._peer @asyncio.coroutine def _process_packets(self): """Process the queue of incoming packets""" while True: data = yield from self._packet_queue.get() if data is None or isinstance(data, Exception): self._queue_task = None self.connection_lost(data) break self.process_packet(data) def connection_lost(self, exc): """Handle 
the closing of a connection""" raise NotImplementedError def process_packet(self, data): """Process an incoming packet""" raise NotImplementedError def queue_packet(self, data): """Add an incoming packet to the queue""" self._packet_queue.put_nowait(data) def send_packet(self, *args): """Send a packet to this connection's peer""" if self._peer: self._peer.queue_packet(b''.join(args)) def close(self): """Close the connection, stopping processing of incoming packets""" if self._peer: self._peer.queue_packet(None) self._peer = None if self._queue_task: self.queue_packet(None) self._queue_task = None class TempDirTestCase(unittest.TestCase): """Unit test class which operates in a temporary directory""" tempdir = None @classmethod def setUpClass(cls): """Create temporary directory and set it as current directory""" cls._tempdir = tempfile.TemporaryDirectory() os.chdir(cls._tempdir.name) @classmethod def tearDownClass(cls): """Clean up temporary directory""" os.chdir('..') cls._tempdir.cleanup() class AsyncTestCase(TempDirTestCase): """Unit test class which supports tests using asyncio""" loop = None @classmethod def setUpClass(cls): """Set up event loop to run async tests and run async class setup""" super().setUpClass() cls.loop = asyncio.new_event_loop() asyncio.set_event_loop(cls.loop) try: # pylint: disable=no-member cls.loop.run_until_complete(cls.asyncSetUpClass()) except AttributeError: pass @classmethod def tearDownClass(cls): """Run async class teardown and close event loop""" try: # pylint: disable=no-member cls.loop.run_until_complete(cls.asyncTearDownClass()) except AttributeError: pass cls.loop.close() super().tearDownClass() def setUp(self): """Run async setup if any""" try: # pylint: disable=no-member self.loop.run_until_complete(self.asyncSetUp()) except AttributeError: pass def tearDown(self): """Run async teardown if any""" try: # pylint: disable=no-member self.loop.run_until_complete(self.asyncTearDown()) except AttributeError: pass asyncssh-1.11.1/tests_py35/000077500000000000000000000000001320320510200154345ustar00rootroot00000000000000asyncssh-1.11.1/tests_py35/test_connection.py000066400000000000000000000023411320320510200212040ustar00rootroot00000000000000# Copyright (c) 2016 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for AsyncSSH connection API on Python 3.5 and later""" from tests.server import ServerTestCase from tests.util import asynctest, asynctest35 class _TestConnection(ServerTestCase): """Unit tests for AsyncSSH connection async context manager""" # pylint: disable=not-async-context-manager @asynctest35 async def test_connect(self): """Test connecting in Python 3.5 with async context manager""" async with self.connect(): pass @asynctest35 async def test_connect_await(self): """Test connecting with await and async context manager""" conn = await self.connect() async with conn: pass @asynctest def test_connect_yield(self): """Test connecting with yield from""" with (yield from self.connect()): pass asyncssh-1.11.1/tests_py35/test_process.py000066400000000000000000000024171320320510200205270ustar00rootroot00000000000000# Copyright (c) 2016 by Ron Frederick . # All rights reserved. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for AsyncSSH process API on Python 3.5 and later""" from tests.server import ServerTestCase from tests.util import asynctest35, echo class _TestStream(ServerTestCase): """Unit tests for AsyncSSH stream API""" # pylint: disable=not-async-context-manager @classmethod async def start_server(cls): """Start an SSH server for the tests to use""" return await cls.create_server(session_factory=echo) @asynctest35 async def test_shell(self): """Test starting a remote shell""" data = str(id(self)) async with self.connect() as conn: async with conn.create_process() as process: process.stdin.write(data) process.stdin.write_eof() result = await process.wait() self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) asyncssh-1.11.1/tests_py35/test_sftp.py000066400000000000000000000055521320320510200200300ustar00rootroot00000000000000# Copyright (c) 2016 by Ron Frederick . # All rights reserved. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v1.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-v10.html # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for AsyncSSH SFTP client and server on Python 3.5 and later""" import os from tests.server import ServerTestCase from tests.util import asynctest, asynctest35 class _TestSFTP(ServerTestCase): """Unit tests for AsyncSSH SFTP async context manager""" # pylint: disable=not-async-context-manager @classmethod async def start_server(cls): """Start an SFTP server for the tests to use""" return await cls.create_server(sftp_factory=True) @asynctest35 async def test_sftp(self): """Test starting SFTP in Python 3.5 with async context manager""" async with self.connect() as conn: async with conn.start_sftp_client(): pass @asynctest35 async def test_sftp_await(self): """Test starting SFTP with await and async context manager""" async with self.connect() as conn: sftp = await conn.start_sftp_client() async with sftp: pass @asynctest def test_sftp_yield(self): """Test starting SFTP with yield from""" with (yield from self.connect()) as conn: with (yield from conn.start_sftp_client()): pass @asynctest35 async def test_sftp_open(self): """Test opening SFTP file in Python 3.5 with async context manager""" async with self.connect() as conn: async with conn.start_sftp_client() as sftp: try: async with sftp.open('file', 'w'): pass finally: os.unlink('file') @asynctest35 async def test_sftp_open_await(self): """Test opening SFTP file with await and async context manager""" async with self.connect() as conn: sftp = await conn.start_sftp_client() async with sftp: try: async with sftp.open('file', 'w'): pass finally: os.unlink('file') @asynctest def test_sftp_open_yield(self): """Test opening SFTP file with yield from""" with (yield from self.connect()) as conn: with (yield from conn.start_sftp_client()) as sftp: f = None try: f = yield from sftp.open('file', 'w') finally: if f: # pragma: no branch yield from f.close() os.unlink('file') asyncssh-1.11.1/tests_py35/test_stream.py000066400000000000000000000027721320320510200203500ustar00rootroot00000000000000# Copyright (c) 2016 by Ron 
# Frederick .
# All rights reserved.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v1.0 which accompanies this
# distribution and is available at:
#
#     http://www.eclipse.org/legal/epl-v10.html
#
# Contributors:
#     Ron Frederick - initial implementation, API, and documentation

"""Unit tests for AsyncSSH stream API on Python 3.5 and later"""

from tests.server import Server, ServerTestCase
from tests.util import asynctest35, echo


class _StreamServer(Server):
    """Server for testing the AsyncSSH stream API in Python 3.5"""

    def session_requested(self):
        """Handle a request to create a new session"""

        return echo


class _TestStream(ServerTestCase):
    """Unit tests for AsyncSSH stream API"""

    # pylint: disable=not-async-context-manager

    @classmethod
    async def start_server(cls):
        """Start an SSH server for the tests to use"""

        return await cls.create_server(_StreamServer)

    @asynctest35
    async def test_async_iterator(self):
        """Test reading lines by using SSHReader as an async iterator"""

        async with self.connect() as conn:
            stdin, stdout, _ = await conn.open_session()

            data = ['Line 1\n', 'Line 2\n']

            stdin.writelines(data)
            stdin.write_eof()

            async for line in stdout:
                self.assertEqual(line, data.pop(0))

            self.assertEqual(data, [])

asyncssh-1.11.1/tox.ini000066400000000000000000000007151320320510200147300ustar00rootroot00000000000000
[tox]
envlist = {py34,py35,py36}-{linux,macos,windows},py37-linux

[testenv]
deps =
    bcrypt
    coverage
    linux,macos: gssapi
    libnacl
    pyOpenSSL
    windows: pypiwin32
platform =
    linux: linux
    macos: darwin
    windows: win32
sitepackages = True
skip_missing_interpreters = True
usedevelop = True
commands =
    {envpython} -m coverage run -p -m unittest
    py35,py36,py37: {envpython} -m coverage run -p -m unittest discover -s tests_py35
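# Example usage (illustrative sketch, relying only on tox's standard -e
# selection and {posargs} substitution): a single environment can be run
# locally with
#
#   tox -e py36-linux
#
# and extra unittest arguments could be forwarded by appending {posargs}
# to the commands above, e.g.:
#
#   commands = {envpython} -m coverage run -p -m unittest {posargs}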