pax_global_header00006660000000000000000000000064147645041540014524gustar00rootroot0000000000000052 comment=88595a558963282534ef6a9dd871c3c0bb24bbfb elastic-transport-python-8.17.1/000077500000000000000000000000001476450415400166175ustar00rootroot00000000000000elastic-transport-python-8.17.1/.github/000077500000000000000000000000001476450415400201575ustar00rootroot00000000000000elastic-transport-python-8.17.1/.github/workflows/000077500000000000000000000000001476450415400222145ustar00rootroot00000000000000elastic-transport-python-8.17.1/.github/workflows/backport.yml000066400000000000000000000013001476450415400245360ustar00rootroot00000000000000name: Backport on: pull_request_target: types: - closed - labeled jobs: backport: name: Backport runs-on: ubuntu-latest # Only react to merged PRs for security reasons. # See https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#pull_request_target. if: > github.event.pull_request.merged && ( github.event.action == 'closed' || ( github.event.action == 'labeled' && contains(github.event.label.name, 'backport') ) ) steps: - uses: tibdex/backport@9565281eda0731b1d20c4025c43339fb0a23812e # v2.0.4 with: github_token: ${{ secrets.GITHUB_TOKEN }} elastic-transport-python-8.17.1/.github/workflows/ci.yml000066400000000000000000000042601476450415400233340ustar00rootroot00000000000000--- name: CI on: [push, pull_request] env: FORCE_COLOR: 1 jobs: package: runs-on: ubuntu-latest steps: - name: Checkout repository uses: actions/checkout@v1 - name: Set up Python 3.x uses: actions/setup-python@v5 with: python-version: 3.x - name: Install dependencies run: python3 -m pip install setuptools wheel twine - name: Build dists run: python3 utils/build-dists.py lint: runs-on: ubuntu-latest steps: - name: Checkout Repository uses: actions/checkout@v1 - name: Set up Python 3.x uses: actions/setup-python@v5 with: python-version: 3.x - name: Install dependencies run: python3 -m pip install nox - name: Lint the code run: nox -s lint env: # Workaround for development versions # https://github.com/aio-libs/aiohttp/issues/7675 AIOHTTP_NO_EXTENSIONS: 1 test: strategy: fail-fast: false matrix: python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] os: ["ubuntu-latest"] experimental: [false] nox-session: [''] include: - python-version: "3.8" os: "ubuntu-latest" experimental: false nox-session: "test-min-deps" runs-on: ${{ matrix.os }} name: test-${{ matrix.python-version }} ${{ matrix.nox-session }} continue-on-error: ${{ matrix.experimental }} steps: - name: Checkout repository uses: actions/checkout@v2 - name: Set up Python - ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} allow-prereleases: true - name: Install Dependencies run: python -m pip install --upgrade nox - name: Run tests run: nox -s ${NOX_SESSION:-test-$PYTHON_VERSION} env: PYTHON_VERSION: ${{ matrix.python-version }} NOX_SESSION: ${{ matrix.nox-session }} # Required for development versions of Python AIOHTTP_NO_EXTENSIONS: 1 FROZENLIST_NO_EXTENSIONS: 1 YARL_NO_EXTENSIONS: 1 elastic-transport-python-8.17.1/.gitignore000066400000000000000000000040431476450415400206100ustar00rootroot00000000000000# Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/sphinx/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # PEP 582; used by e.g. github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ # sample code for GitHub issues issues/ elastic-transport-python-8.17.1/.readthedocs.yml000066400000000000000000000010271476450415400217050ustar00rootroot00000000000000version: 2 build: os: ubuntu-22.04 tools: # To work around https://github.com/aio-libs/aiohttp/issues/7675, we need # to set AIOHTTP_NO_EXTENSIONS to 1 but it has to be done in # https://readthedocs.org/dashboard/elastic-transport-python/environmentvariables/ # because of https://github.com/readthedocs/readthedocs.org/issues/6311 python: "3" python: install: - method: pip path: . extra_requirements: - develop sphinx: configuration: docs/sphinx/conf.py fail_on_warning: true elastic-transport-python-8.17.1/CHANGELOG.md000066400000000000000000000155631476450415400204420ustar00rootroot00000000000000# Changelog ## 8.17.1 (2025-03-12) * Ensure compatibility with httpx v0.28.0+ ([#222](https://github.com/elastic/elastic-transport-python/pull/222), contributed by Arch Linux maintainer @carlsmedstad) * Add missing NOTICE file ([#229](https://github.com/elastic/elastic-transport-python/pull/229), reported by Debian Maintainer @schoekek) ## 8.17.0 (2025-01-07) * Fix use of SSLContext with sniffing ([#199](https://github.com/elastic/elastic-transport-python/pull/199)) * Fix enabled_cleanup_closed warning ([#202](https://github.com/elastic/elastic-transport-python/pull/202)) * Remove unneeded install requirement ([#196](https://github.com/elastic/elastic-transport-python/pull/196)) * Fix aiohttp call type: ignore differently ([#190](https://github.com/elastic/elastic-transport-python/pull/190)) ## 8.15.1 (2024-10-09) * Add explicit Python 3.13 support ([#189](https://github.com/elastic/elastic-transport-python/pull/189)) ## 8.15.0 (2024-08-09) * Removed call to `raise_for_status()` when using `HttpxAsyncHttpNode` to prevent exceptions being raised for 404 responses ([#182](https://github.com/elastic/elastic-transport-python/pull/182)) * Documented response classes ([#175](https://github.com/elastic/elastic-transport-python/pull/175)) * Dropped support for Python 3.7 ([#179](https://github.com/elastic/elastic-transport-python/pull/179)) ## 8.13.1 (2024-04-28) - Fixed requests 2.32 compatibility (#164) - Fixed TypeError when two nodes are declared dead at the same time (#115, contributed by @floxay) - Added `TransportApiResponse` (#160, #161, contributed by @JessicaGarson) ## 8.13.0 - Added support for the HTTPX client with asyncio (#137, contributed by @b4sus) - Added optional orjson serializer support (#152) ## 8.12.0 - Fix basic auth built from percent-encoded URLs (#143) ## 8.11.0 - Always set default HTTPS port to 443 (#127) - Drop support for Python 3.6 (#109) - Include tests in sdist (#122, contributed by @parona-source) - Fix `__iter__` return type to Iterator (#129, contributed by @altescy) ## 8.10.0 - Support urllib3 2.x in addition to urllib3 1.26.x ([#121](https://github.com/elastic/elastic-transport-python/pull/121)) - Add 409 to `NOT_DEAD_NODE_HTTP_STATUSES` ([#120](https://github.com/elastic/elastic-transport-python/pull/120)) ## 8.4.1 - Fixed an issue where a large number of consecutive failures to connect to a node would raise an `OverflowError`. - Fixed an issue to ensure that `ApiResponse` can be pickled. ## 8.4.0 ### Added - Added method for clients to use default ports for URL scheme. ## 8.1.2 ### Fixed - Fixed issue when connecting to an IP address with HTTPS enabled would result in a `ValueError` for a mismatch between `check_hostname` and `server_hostname`. ## 8.1.1 ### Fixed - Fixed `JsonSerializer` to return `None` if a response using `Content-Type: application/json` is empty instead of raising an error. ## 8.1.0 ### Fixed - Fixed `Urllib3HttpNode` and `RequestsHttpNode` to never require a valid certificate chain when using `ssl_assert_fingerprint`. Instead the internal HTTP client libraries will explicitly disable verifying the certificate chain and instead rely only on the certificate fingerprint for verification. ## 8.0.1 ### Fixed - Fixed `AiohttpHttpNode` to close TLS connections that aren't properly shutdown by the server instead of leaking them - Fixed `Urllib3HttpNode` to respect `path_prefix` setting in `NodeConfig` ## 8.0.0 ### Added - Added support for asyncio with `AsyncTransport` and `AiohttpHttpNode` - Added `JsonSerializer`, `NdjsonSerializer` - Added `connections_per_node` parameter to `RequestsHttpNode` - Added support for `ssl_assert_fingerprint` to `RequestsHttpNode` - Added **experimental** support for pinning non-leaf certificates via `ssl_assert_fingerprint` when using CPython 3.10+ - Added support for node discovery via "sniffing" using the `sniff_callback` transport parameter - Added ability to specify `ssl_version` via `ssl.TLSVersion` enum instead of `ssl.PROTOCOL_TLSvX` for Python 3.7+ - Added `elastic_transport.client_utils` module to help writing API clients - Added `elastic_transport.debug_logging` method to enable all logging for debugging purposes - Added option to set `requests.Session.auth` within `RequestsHttpNode` via `NodeConfig._extras['requests.session.auth']` ### Changed - Changed `*Connection` classes to use `*Node` terminology - Changed `connection_class` to `node_class` - Changed `ConnectionPool` to `NodePool` - Changed `ConnectionSelector` to `NodeSelector` - Changed `NodeSelector(randomize_hosts)` parameter to `randomize_nodes` - Changed `NodeSelector.get_connection()` method to `get()` - Changed `elastic_transport.connection` logger name to `elastic_transport.node` - Changed `Urllib3HttpNode(connections_per_host)` parameter to `connections_per_node` - Changed return type of `BaseNode.perform_request()` to `NamedTuple(meta=ApiResponseMeta, body=bytes)` - Changed return type of `Transport.perform_request()` to `NamedTuple(meta=ApiResponseMeta, body=Any)` - Changed name of `Deserializer` into `SerializersCollection` - Changed `ssl_version` to denote the minimum TLS version instead of the only TLS version - Changed the base class for `ApiError` to be `Exception` instead of `TransportError`. `TransportError` is now only for errors that occur at the transport layer. - Changed `Urllib3HttpNode` to block on new connections when the internal connection pool is exhausted ### Removed - Removed support for Python 2.7 - Removed `DummyConnectionPool` and `EmptyConnectionPool` in favor of `NodePool`. ### Fixed - Fixed a work-around with `AiohttpHttpNode` where `method="HEAD"` requests wouldn't mark the internal connection as reusable. This work-around is no longer needed when `aiohttp>=3.7.0` is installed. - Fixed logic for splitting `aiohttp.__version__` when determining if `HEAD` bug is fixed. ## 7.15.0 (2021-09-20) Release created to be compatible with 7.15 clients ## 7.14.0 (2021-08-02) Release created to be compatible with 7.14 clients ## 7.13.0 (2021-05-24) Release created to be compatible with 7.13 clients ## 7.12.0 (2021-03-22) Release created to be compatible with 7.12 clients ## 7.11.0 (2021-02-10) ### Added - Added the `X-Elastic-Client-Meta` HTTP header ([PR #4](https://github.com/elastic/elastic-transport-python/pull/4)) - Added HTTP response headers to `Response` and `TransportError` ([PR #5](https://github.com/elastic/elastic-transport-python/pull/5)) - Added the `QueryParams` data structure for representing an ordered sequence of key-value pairs for the URL query ([PR #6](https://github.com/elastic/elastic-transport-python/pull/6)) ### Changed - Changed `Connection.perform_request()` to take `target` instead of `path` and `params`. Instead `path` and `params` are created within `Transport.perform_request()` ([PR #6](https://github.com/elastic/elastic-transport-python/pull/6)) ## 0.1.0b0 (2020-10-21) - Initial beta release of `elastic-transport-python` elastic-transport-python-8.17.1/LICENSE000066400000000000000000000236371476450415400176370ustar00rootroot00000000000000 Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. elastic-transport-python-8.17.1/MANIFEST.in000066400000000000000000000003341476450415400203550ustar00rootroot00000000000000include LICENSE include MANIFEST.in include README.md include CHANGELOG.md include setup.py include elastic_transport/py.typed graft tests prune docs/_build recursive-exclude * __pycache__ recursive-exclude * *.py[co] elastic-transport-python-8.17.1/NOTICE000066400000000000000000000001071476450415400175210ustar00rootroot00000000000000Elastic Transport Library for Python Copyright 2025 Elasticsearch B.V. elastic-transport-python-8.17.1/README.md000066400000000000000000000025221476450415400200770ustar00rootroot00000000000000# elastic-transport-python [![PyPI](https://img.shields.io/pypi/v/elastic-transport)](https://pypi.org/project/elastic-transport) [![Python Versions](https://img.shields.io/pypi/pyversions/elastic-transport)](https://pypi.org/project/elastic-transport) [![PyPI Downloads](https://static.pepy.tech/badge/elastic-transport)](https://pepy.tech/project/elastic-transport) [![CI Status](https://img.shields.io/github/actions/workflow/status/elastic/elastic-transport-python/ci.yml)](https://github.com/elastic/elastic-transport-python/actions) Transport classes and utilities shared among Python Elastic client libraries This library was lifted from [`elasticsearch-py`](https://github.com/elastic/elasticsearch-py) and then transformed to be used across all Elastic services rather than only Elasticsearch. ### Installing from PyPI ``` $ python -m pip install elastic-transport ``` Versioning follows the major and minor version of the Elastic Stack version and the patch number is incremented for bug fixes within a minor release. ## Documentation Documentation including an API reference is available on [Read the Docs](https://elastic-transport-python.readthedocs.io). ## License `elastic-transport-python` is available under the Apache-2.0 license. For more details see [LICENSE](https://github.com/elastic/elastic-transport-python/blob/main/LICENSE). elastic-transport-python-8.17.1/docs/000077500000000000000000000000001476450415400175475ustar00rootroot00000000000000elastic-transport-python-8.17.1/docs/sphinx/000077500000000000000000000000001476450415400210605ustar00rootroot00000000000000elastic-transport-python-8.17.1/docs/sphinx/client_utils.rst000066400000000000000000000003371476450415400243130ustar00rootroot00000000000000Client Utilities ================ Reusable utilities for creating API clients using ``elastic_transport``. .. py:currentmodule:: elastic_transport.client_utils .. automodule:: elastic_transport.client_utils :members: elastic-transport-python-8.17.1/docs/sphinx/conf.py000066400000000000000000000027211476450415400223610ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import datetime import os import sys sys.path.insert(0, os.path.abspath("../..")) from elastic_transport import __version__ # noqa project = "elastic-transport" copyright = f"{datetime.date.today().year} Elasticsearch B.V." author = "Seth Michael Larson" version = __version__ release = __version__ extensions = [ "sphinx.ext.autodoc", "sphinx.ext.intersphinx", "sphinx_autodoc_typehints", ] pygments_style = "sphinx" pygments_dark_style = "monokai" templates_path = [] exclude_patterns = [] html_theme = "furo" html_static_path = [] intersphinx_mapping = { "python": ("https://docs.python.org/3", None), "requests": ("https://docs.python-requests.org/en/latest", None), } elastic-transport-python-8.17.1/docs/sphinx/exceptions.rst000066400000000000000000000010441476450415400237720ustar00rootroot00000000000000Exceptions & Warnings ===================== .. py:currentmodule:: elastic_transport Transport Errors ---------------- .. autoclass:: TransportError :members: .. autoclass:: TlsError :members: .. autoclass:: ConnectionError :members: .. autoclass:: ConnectionTimeout :members: .. autoclass:: SerializationError :members: .. autoclass:: SniffingError :members: .. autoclass:: ApiError :members: Warnings -------- .. py:currentmodule:: elastic_transport .. autoclass:: TransportWarning .. autoclass:: SecurityWarning elastic-transport-python-8.17.1/docs/sphinx/index.rst000066400000000000000000000002631476450415400227220ustar00rootroot00000000000000API Reference ============= .. toctree:: :maxdepth: 2 installation nodes responses exceptions logging transport node_pool serializers client_utils elastic-transport-python-8.17.1/docs/sphinx/installation.rst000066400000000000000000000007101476450415400243110ustar00rootroot00000000000000Installation ============ Install with ``pip`` like so: ``$ python -m pip install elastic-transport`` Additional dependencies are required to use some features of the ``elastic-transport`` package. Install the ``requests`` package to use :class:`elastic_transport.RequestsHttpNode`. Install the ``aiohttp`` package to use :class:`elastic_transport.AiohttpHttpNode`. Install the ``httpx`` package to use :class:`elastic_transport.HttpxAsyncHttpNode`. elastic-transport-python-8.17.1/docs/sphinx/logging.rst000066400000000000000000000046261476450415400232500ustar00rootroot00000000000000Logging ======= .. py:currentmodule:: elastic_transport Available loggers ----------------- - ``elastic_transport.node_pool``: Logs activity within the :class:`elastic_transport.NodePool` like nodes switching between "alive" and "dead" - ``elastic_transport.transport``: Logs requests and responses in addition to retries, errors, and sniffing. - ``elastic_transport.node``: Logs all network activity for individual :class:`elastic_transport.BaseNode` instances. This logger is recommended only for human debugging as the logs are unstructured and meant primarily for human consumption from the command line. Debugging requests and responses -------------------------------- .. autofunction:: elastic_transport.debug_logging .. warning:: This method shouldn't be enabled in production as it's extremely verbose. Should only be used for debugging manually. .. code-block:: python import elastic_transport from elasticsearch import Elasticsearch # In this example we're debugging an Elasticsearch client: client = Elasticsearch(...) # Use `elastic_transport.debug_logging()` before the request elastic_transport.debug_logging() client.search( index="example-index", query={ "match": { "text-field": "value" } }, typed_keys=True ) The following script will output these logs about the HTTP request and response: .. code-block:: [2021-11-23T14:11:20] > POST /example-index/_search?typed_keys=true HTTP/1.1 > Accept: application/json > Accept-Encoding: gzip > Authorization: Basic > Connection: keep-alive > Content-Encoding: gzip > Content-Type: application/json > User-Agent: elastic-transport-python/8.1.0+dev > X-Elastic-Client-Meta: es=8.1.0p,py=3.9.2,t=8.1.0p,ur=1.26.7 > {"query":{"match":{"text-field":"value"}}} < HTTP/1.1 200 OK < Content-Encoding: gzip < Content-Length: 165 < Content-Type: application/json;charset=utf-8 < Date: Tue, 23 Nov 2021 20:11:20 GMT < X-Cloud-Request-Id: ctSE59hPSCugrCPM4A2GUQ < X-Elastic-Product: Elasticsearch < X-Found-Handling-Cluster: 40c9b5837c8f4dd083f05eac950fd50c < X-Found-Handling-Instance: instance-0000000001 < {"hits":{...}} Notice how the ``Authorization`` HTTP header is hidden and the complete HTTP request and response method, target, headers, status, and bodies are logged for debugging. elastic-transport-python-8.17.1/docs/sphinx/node_pool.rst000066400000000000000000000003551476450415400235730ustar00rootroot00000000000000Node Pool ========= .. py:currentmodule:: elastic_transport .. autoclass:: NodePool :members: Node selectors -------------- .. autoclass:: NodeSelector :members: .. autoclass:: RandomSelector .. autoclass:: RoundRobinSelector elastic-transport-python-8.17.1/docs/sphinx/nodes.rst000066400000000000000000000027031476450415400227240ustar00rootroot00000000000000Nodes ===== .. py:currentmodule:: elastic_transport Configuring nodes ----------------- .. autoclass:: elastic_transport::NodeConfig :members: Node classes ------------ .. autoclass:: Urllib3HttpNode :members: .. autoclass:: RequestsHttpNode :members: .. autoclass:: AiohttpHttpNode :members: .. autoclass:: HttpxAsyncHttpNode :members: Custom node classes ------------------- You can define your own node class like so: .. code-block:: python from typing import Optional from elastic_transport import Urllib3HttpNode, NodeConfig, ApiResponseMeta, HttpHeaders from elastic_transport.client_utils import DefaultType, DEFAULT class CustomHttpNode(Urllib3HttpNode): def perform_request( self, method: str, target: str, body: Optional[bytes] = None, headers: Optional[HttpHeaders] = None, request_timeout: Union[DefaultType, Optional[float]] = DEFAULT, ) -> Tuple[ApiResponseMeta, bytes]: # Define your HTTP request method here... and once you have a custom node class you can pass the class to :class:`elastic_transport.Transport` or an API client like so: .. code-block:: python # Example using a Transport instance: from elastic_transport import Transport transport = Transport(..., node_class=CustomHttpNode) # Example using an API client: from elasticsearch import Elasticsearch client = Elasticsearch(..., node_class=CustomHttpNode) elastic-transport-python-8.17.1/docs/sphinx/responses.rst000066400000000000000000000012031476450415400236270ustar00rootroot00000000000000Responses ========= .. py:currentmodule:: elastic_transport Response headers ---------------- .. autoclass:: elastic_transport::HttpHeaders :members: freeze Metadata -------- .. autoclass:: ApiResponseMeta :members: Response classes ---------------- .. autoclass:: ApiResponse :members: .. autoclass:: BinaryApiResponse :members: :show-inheritance: .. autoclass:: HeadApiResponse :members: :show-inheritance: .. autoclass:: ListApiResponse :members: :show-inheritance: .. autoclass:: ObjectApiResponse :members: :show-inheritance: .. autoclass:: TextApiResponse :members: :show-inheritance: elastic-transport-python-8.17.1/docs/sphinx/serializers.rst000066400000000000000000000004351476450415400241500ustar00rootroot00000000000000Serializers =========== .. py:currentmodule:: elastic_transport .. autoclass:: Serializer :members: .. autoclass:: JsonSerializer :members: .. autoclass:: OrjsonSerializer :members: .. autoclass:: TextSerializer :members: .. autoclass:: NdjsonSerializer :members: elastic-transport-python-8.17.1/docs/sphinx/transport.rst000066400000000000000000000002201476450415400236400ustar00rootroot00000000000000Transport ========= .. py:currentmodule:: elastic_transport .. autoclass:: Transport :members: .. autoclass:: AsyncTransport :members: elastic-transport-python-8.17.1/elastic_transport/000077500000000000000000000000001476450415400223575ustar00rootroot00000000000000elastic-transport-python-8.17.1/elastic_transport/__init__.py000066400000000000000000000077521476450415400245030ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. """Transport classes and utilities shared among Python Elastic client libraries""" import logging from ._async_transport import AsyncTransport as AsyncTransport from ._exceptions import ( ApiError, ConnectionError, ConnectionTimeout, SecurityWarning, SerializationError, SniffingError, TlsError, TransportError, TransportWarning, ) from ._models import ApiResponseMeta, HttpHeaders, NodeConfig, SniffOptions from ._node import ( AiohttpHttpNode, BaseAsyncNode, BaseNode, HttpxAsyncHttpNode, RequestsHttpNode, Urllib3HttpNode, ) from ._node_pool import NodePool, NodeSelector, RandomSelector, RoundRobinSelector from ._otel import OpenTelemetrySpan from ._response import ApiResponse as ApiResponse from ._response import BinaryApiResponse as BinaryApiResponse from ._response import HeadApiResponse as HeadApiResponse from ._response import ListApiResponse as ListApiResponse from ._response import ObjectApiResponse as ObjectApiResponse from ._response import TextApiResponse as TextApiResponse from ._serializer import ( JsonSerializer, NdjsonSerializer, Serializer, SerializerCollection, TextSerializer, ) from ._transport import Transport as Transport from ._transport import TransportApiResponse from ._utils import fixup_module_metadata from ._version import __version__ as __version__ # noqa __all__ = [ "AiohttpHttpNode", "ApiError", "ApiResponse", "ApiResponseMeta", "AsyncTransport", "BaseAsyncNode", "BaseNode", "BinaryApiResponse", "ConnectionError", "ConnectionTimeout", "HeadApiResponse", "HttpHeaders", "HttpxAsyncHttpNode", "JsonSerializer", "ListApiResponse", "NdjsonSerializer", "NodeConfig", "NodePool", "NodeSelector", "ObjectApiResponse", "OpenTelemetrySpan", "RandomSelector", "RequestsHttpNode", "RoundRobinSelector", "SecurityWarning", "SerializationError", "Serializer", "SerializerCollection", "SniffOptions", "SniffingError", "TextApiResponse", "TextSerializer", "TlsError", "Transport", "TransportApiResponse", "TransportError", "TransportWarning", "Urllib3HttpNode", ] try: from elastic_transport._serializer import OrjsonSerializer # noqa: F401 __all__.append("OrjsonSerializer") except ImportError: pass _logger = logging.getLogger("elastic_transport") _logger.addHandler(logging.NullHandler()) del _logger fixup_module_metadata(__name__, globals()) del fixup_module_metadata def debug_logging() -> None: """Enables logging on all ``elastic_transport.*`` loggers and attaches a :class:`logging.StreamHandler` instance to each. This is an easy way to visualize the network activity occurring on the client or debug a client issue. """ handler = logging.StreamHandler() formatter = logging.Formatter( "[%(asctime)s] %(message)s", datefmt="%Y-%m-%dT%H:%M:%S" ) handler.setFormatter(formatter) for logger in ( logging.getLogger("elastic_transport.node"), logging.getLogger("elastic_transport.node_pool"), logging.getLogger("elastic_transport.transport"), ): logger.addHandler(handler) logger.setLevel(logging.DEBUG) elastic-transport-python-8.17.1/elastic_transport/_async_transport.py000066400000000000000000000473661476450415400263410ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import asyncio import logging from typing import ( Any, Awaitable, Callable, Collection, List, Mapping, Optional, Tuple, Type, Union, ) from ._compat import await_if_coro from ._exceptions import ( ConnectionError, ConnectionTimeout, SniffingError, TransportError, ) from ._models import DEFAULT, DefaultType, HttpHeaders, NodeConfig, SniffOptions from ._node import AiohttpHttpNode, BaseAsyncNode from ._node_pool import NodePool, NodeSelector from ._otel import OpenTelemetrySpan from ._serializer import Serializer from ._transport import ( DEFAULT_CLIENT_META_SERVICE, NOT_DEAD_NODE_HTTP_STATUSES, Transport, TransportApiResponse, validate_sniffing_options, ) from .client_utils import resolve_default _logger = logging.getLogger("elastic_transport.transport") class AsyncTransport(Transport): """ Encapsulation of transport-related to logic. Handles instantiation of the individual nodes as well as creating a node pool to hold them. Main interface is the :meth:`elastic_transport.Transport.perform_request` method. """ def __init__( self, node_configs: List[NodeConfig], node_class: Union[str, Type[BaseAsyncNode]] = AiohttpHttpNode, node_pool_class: Type[NodePool] = NodePool, randomize_nodes_in_pool: bool = True, node_selector_class: Optional[Union[str, Type[NodeSelector]]] = None, dead_node_backoff_factor: Optional[float] = None, max_dead_node_backoff: Optional[float] = None, serializers: Optional[Mapping[str, Serializer]] = None, default_mimetype: str = "application/json", max_retries: int = 3, retry_on_status: Collection[int] = (429, 502, 503, 504), retry_on_timeout: bool = False, sniff_on_start: bool = False, sniff_before_requests: bool = False, sniff_on_node_failure: bool = False, sniff_timeout: Optional[float] = 0.5, min_delay_between_sniffing: float = 10.0, sniff_callback: Optional[ Callable[ ["AsyncTransport", "SniffOptions"], Union[List[NodeConfig], Awaitable[List[NodeConfig]]], ] ] = None, meta_header: bool = True, client_meta_service: Tuple[str, str] = DEFAULT_CLIENT_META_SERVICE, ): """ :arg node_configs: List of 'NodeConfig' instances to create initial set of nodes. :arg node_class: subclass of :class:`~elastic_transport.BaseNode` to use or the name of the Connection (ie 'urllib3', 'requests') :arg node_pool_class: subclass of :class:`~elastic_transport.NodePool` to use :arg randomize_nodes_in_pool: Set to false to not randomize nodes within the pool. Defaults to true. :arg node_selector_class: Class to be used to select nodes within the :class:`~elastic_transport.NodePool`. :arg dead_node_backoff_factor: Exponential backoff factor to calculate the amount of time to timeout a node after an unsuccessful API call. :arg max_dead_node_backoff: Maximum amount of time to timeout a node after an unsuccessful API call. :arg serializers: optional dict of serializer instances that will be used for deserializing data coming from the server. (key is the mimetype) :arg max_retries: Maximum number of retries for an API call. Set to 0 to disable retries. Defaults to ``0``. :arg retry_on_status: set of HTTP status codes on which we should retry on a different node. defaults to ``(429, 502, 503, 504)`` :arg retry_on_timeout: should timeout trigger a retry on different node? (default ``False``) :arg sniff_on_start: If ``True`` will sniff for additional nodes as soon as possible, guaranteed before the first request. :arg sniff_on_node_failure: If ``True`` will sniff for additional nodees after a node is marked as dead in the pool. :arg sniff_before_requests: If ``True`` will occasionally sniff for additional nodes as requests are sent. :arg sniff_timeout: Timeout value in seconds to use for sniffing requests. Defaults to 1 second. :arg min_delay_between_sniffing: Number of seconds to wait between calls to :meth:`elastic_transport.Transport.sniff` to avoid sniffing too frequently. Defaults to 10 seconds. :arg sniff_callback: Function that is passed a :class:`elastic_transport.Transport` and :class:`elastic_transport.SniffOptions` and should do node discovery and return a list of :class:`elastic_transport.NodeConfig` instances or a coroutine that returns the list. """ # Since we don't pass all the sniffing options to super().__init__() # we want to validate the sniffing options here too. validate_sniffing_options( node_configs=node_configs, sniff_on_start=sniff_on_start, sniff_before_requests=sniff_before_requests, sniff_on_node_failure=sniff_on_node_failure, sniff_callback=sniff_callback, ) super().__init__( node_configs=node_configs, node_class=node_class, node_pool_class=node_pool_class, randomize_nodes_in_pool=randomize_nodes_in_pool, node_selector_class=node_selector_class, dead_node_backoff_factor=dead_node_backoff_factor, max_dead_node_backoff=max_dead_node_backoff, serializers=serializers, default_mimetype=default_mimetype, max_retries=max_retries, retry_on_status=retry_on_status, retry_on_timeout=retry_on_timeout, sniff_timeout=sniff_timeout, min_delay_between_sniffing=min_delay_between_sniffing, meta_header=meta_header, client_meta_service=client_meta_service, ) self._sniff_on_start = sniff_on_start self._sniff_before_requests = sniff_before_requests self._sniff_on_node_failure = sniff_on_node_failure self._sniff_timeout = sniff_timeout self._sniff_callback = sniff_callback # type: ignore self._sniffing_task: Optional["asyncio.Task[Any]"] = None self._last_sniffed_at = 0.0 # We set this to 'None' here but it'll never be None by the # time it's needed. Gets set within '_async_call()' which should # precede all logic within async calls. self._loop: asyncio.AbstractEventLoop = None # type: ignore[assignment] # AsyncTransport doesn't require a thread lock for # sniffing. Uses '_sniffing_task' instead. self._sniffing_lock = None # type: ignore[assignment] async def perform_request( # type: ignore[override, return] self, method: str, target: str, *, body: Optional[Any] = None, headers: Union[Mapping[str, Any], DefaultType] = DEFAULT, max_retries: Union[int, DefaultType] = DEFAULT, retry_on_status: Union[Collection[int], DefaultType] = DEFAULT, retry_on_timeout: Union[bool, DefaultType] = DEFAULT, request_timeout: Union[Optional[float], DefaultType] = DEFAULT, client_meta: Union[Tuple[Tuple[str, str], ...], DefaultType] = DEFAULT, otel_span: Union[OpenTelemetrySpan, DefaultType] = DEFAULT, ) -> TransportApiResponse: """ Perform the actual request. Retrieve a node from the node pool, pass all the information to it's perform_request method and return the data. If an exception was raised, mark the node as failed and retry (up to ``max_retries`` times). If the operation was successful and the node used was previously marked as dead, mark it as live, resetting it's failure count. :arg method: HTTP method to use :arg target: HTTP request target :arg body: body of the request, will be serialized using serializer and passed to the node :arg headers: Additional headers to send with the request. :arg max_retries: Maximum number of retries before giving up on a request. Set to ``0`` to disable retries. :arg retry_on_status: Collection of HTTP status codes to retry. :arg retry_on_timeout: Set to true to retry after timeout errors. :arg request_timeout: Amount of time to wait for a response to fail with a timeout error. :arg client_meta: Extra client metadata key-value pairs to send in the client meta header. :arg otel_span: OpenTelemetry span used to add metadata to the span. :returns: Tuple of the :class:`elastic_transport.ApiResponseMeta` with the deserialized response. """ await self._async_call() if headers is DEFAULT: request_headers = HttpHeaders() else: request_headers = HttpHeaders(headers) max_retries = resolve_default(max_retries, self.max_retries) retry_on_timeout = resolve_default(retry_on_timeout, self.retry_on_timeout) retry_on_status = resolve_default(retry_on_status, self.retry_on_status) otel_span = resolve_default(otel_span, OpenTelemetrySpan(None)) if self.meta_header: request_headers["x-elastic-client-meta"] = ",".join( f"{k}={v}" for k, v in self._transport_client_meta + resolve_default(client_meta, ()) ) # Serialize the request body to bytes based on the given mimetype. request_body: Optional[bytes] if body is not None: if "content-type" not in request_headers: raise ValueError( "Must provide a 'Content-Type' header to requests with bodies" ) request_body = self.serializers.dumps( body, mimetype=request_headers["content-type"] ) otel_span.set_db_statement(request_body) else: request_body = None # Errors are stored from (oldest->newest) errors: List[Exception] = [] for attempt in range(max_retries + 1): # If we sniff before requests are made we want to do so before # 'node_pool.get()' is called so our sniffed nodes show up in the pool. if self._sniff_before_requests: await self.sniff(False) retry = False node_failure = False last_response: Optional[TransportApiResponse] = None node: BaseAsyncNode = self.node_pool.get() # type: ignore[assignment] start_time = self._loop.time() try: otel_span.set_node_metadata(node.host, node.port, node.base_url, target) resp = await node.perform_request( method, target, body=request_body, headers=request_headers, request_timeout=request_timeout, ) _logger.info( "%s %s%s [status:%s duration:%.3fs]" % ( method, node.base_url, target, resp.meta.status, self._loop.time() - start_time, ) ) if method != "HEAD": body = self.serializers.loads(resp.body, resp.meta.mimetype) else: body = None if resp.meta.status in retry_on_status: retry = True # Keep track of the last response we see so we can return # it in case the retried request returns with a transport error. last_response = TransportApiResponse(resp.meta, body) except TransportError as e: _logger.info( "%s %s%s [status:%s duration:%.3fs]" % ( method, node.base_url, target, "N/A", self._loop.time() - start_time, ) ) if isinstance(e, ConnectionTimeout): retry = retry_on_timeout node_failure = True elif isinstance(e, ConnectionError): retry = True node_failure = True # If the error was determined to be a node failure # we mark it dead in the node pool to allow for # other nodes to be retried. if node_failure: self.node_pool.mark_dead(node) if self._sniff_on_node_failure: try: await self.sniff(False) except TransportError: # If sniffing on failure, it could fail too. Catch the # exception not to interrupt the retries. pass if not retry or attempt >= max_retries: # Since we're exhausted but we have previously # received some sort of response from the API # we should forward that along instead of the # transport error. Likely to be more actionable. if last_response is not None: return last_response e.errors = tuple(errors) raise else: _logger.warning( "Retrying request after failure (attempt %d of %d)", attempt, max_retries, exc_info=e, ) errors.append(e) else: # If we got back a response we need to check if that status # is indicative of a healthy node even if it's a non-2XX status if ( 200 <= resp.meta.status < 299 or resp.meta.status in NOT_DEAD_NODE_HTTP_STATUSES ): self.node_pool.mark_live(node) else: self.node_pool.mark_dead(node) if self._sniff_on_node_failure: try: await self.sniff(False) except TransportError: # If sniffing on failure, it could fail too. Catch the # exception not to interrupt the retries. pass # We either got a response we're happy with or # we've exhausted all of our retries so we return it. if not retry or attempt >= max_retries: return TransportApiResponse(resp.meta, body) else: _logger.warning( "Retrying request after non-successful status %d (attempt %d of %d)", resp.meta.status, attempt, max_retries, ) async def sniff(self, is_initial_sniff: bool = False) -> None: # type: ignore[override] await self._async_call() task = self._create_sniffing_task(is_initial_sniff) # Only block on the task if this is the initial sniff. # Otherwise we do the sniffing in the background. if is_initial_sniff and task: await task async def close(self) -> None: # type: ignore[override] """ Explicitly closes all nodes in the transport's pool """ node: BaseAsyncNode for node in self.node_pool.all(): # type: ignore[assignment] await node.close() def _should_sniff(self, is_initial_sniff: bool) -> bool: """Decide if we should sniff or not. _async_init() must be called before using this function.The async implementation doesn't have a lock. """ if is_initial_sniff: return True # Only start a new sniff if the previous run is completed. if self._sniffing_task: if not self._sniffing_task.done(): return False # If there was a previous run we collect the sniffing task's # result as it could have failed with an exception. self._sniffing_task.result() return ( self._loop.time() - self._last_sniffed_at >= self._min_delay_between_sniffing ) def _create_sniffing_task( self, is_initial_sniff: bool ) -> Optional["asyncio.Task[Any]"]: """Creates a sniffing task if one should be created and returns the task if created.""" task = None if self._should_sniff(is_initial_sniff): _logger.info("Started sniffing for additional nodes") # 'self._sniffing_task' is unset within the task implementation. task = self._loop.create_task(self._sniffing_task_impl(is_initial_sniff)) self._sniffing_task = task return task async def _sniffing_task_impl(self, is_initial_sniff: bool) -> None: """Implementation of the sniffing task""" previously_sniffed_at = self._last_sniffed_at try: self._last_sniffed_at = self._loop.time() options = SniffOptions( is_initial_sniff=is_initial_sniff, sniff_timeout=self._sniff_timeout ) assert self._sniff_callback is not None node_configs = await await_if_coro(self._sniff_callback(self, options)) if not node_configs and is_initial_sniff: raise SniffingError( "No viable nodes were discovered on the initial sniff attempt" ) prev_node_pool_size = len(self.node_pool) for node_config in node_configs: self.node_pool.add(node_config) # Do some math to log which nodes are new/existing sniffed_nodes = len(node_configs) new_nodes = sniffed_nodes - (len(self.node_pool) - prev_node_pool_size) existing_nodes = sniffed_nodes - new_nodes _logger.debug( "Discovered %d nodes during sniffing (%d new nodes, %d already in pool)", sniffed_nodes, new_nodes, existing_nodes, ) # If sniffing failed for any reason we # want to allow retrying immediately. except BaseException: self._last_sniffed_at = previously_sniffed_at raise async def _async_call(self) -> None: """Async constructor which is called on the first call to perform_request() because we're not guaranteed to be within an active asyncio event loop when __init__() is called. """ if self._loop is not None: return # Call at most once! self._loop = asyncio.get_running_loop() if self._sniff_on_start: await self.sniff(True) elastic-transport-python-8.17.1/elastic_transport/_compat.py000066400000000000000000000064531476450415400243630ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import inspect import sys from pathlib import Path from typing import Any, Awaitable, TypeVar, Union from urllib.parse import quote as _quote from urllib.parse import urlencode, urlparse string_types = (str, bytes) T = TypeVar("T") async def await_if_coro(coro: Union[T, Awaitable[T]]) -> T: if inspect.iscoroutine(coro): return await coro # type: ignore return coro # type: ignore _QUOTE_ALWAYS_SAFE = frozenset( "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_.-~" ) def quote(string: str, safe: str = "/") -> str: # Redefines 'urllib.parse.quote()' to always have the '~' character # within the 'ALWAYS_SAFE' list. The character was added in Python 3.7 safe = "".join(_QUOTE_ALWAYS_SAFE.union(set(safe))) return _quote(string, safe) try: from threading import Lock except ImportError: class Lock: # type: ignore def __enter__(self) -> None: pass def __exit__(self, *_: Any) -> None: pass def acquire(self, _: bool = True) -> bool: return True def release(self) -> None: pass def warn_stacklevel() -> int: """Dynamically determine warning stacklevel for warnings based on the call stack""" try: # Grab the root module from the current module '__name__' module_name = __name__.partition(".")[0] module_path = Path(sys.modules[module_name].__file__) # type: ignore[arg-type] # If the module is a folder we're looking at # subdirectories, otherwise we're looking for # an exact match. module_is_folder = module_path.name == "__init__.py" if module_is_folder: module_path = module_path.parent # Look through frames until we find a file that # isn't a part of our module, then return that stacklevel. for level, frame in enumerate(inspect.stack()): # Garbage collecting frames frame_filename = Path(frame.filename) del frame if ( # If the module is a folder we look at subdirectory module_is_folder and module_path not in frame_filename.parents ) or ( # Otherwise we're looking for an exact match. not module_is_folder and module_path != frame_filename ): return level except KeyError: pass return 0 __all__ = [ "await_if_coro", "quote", "urlparse", "urlencode", "string_types", "Lock", ] elastic-transport-python-8.17.1/elastic_transport/_exceptions.py000066400000000000000000000074371476450415400252640ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. from typing import Any, Tuple from ._models import ApiResponseMeta class TransportWarning(Warning): """Generic warning for the 'elastic-transport' package.""" class SecurityWarning(TransportWarning): """Warning for potentially insecure configurations.""" class TransportError(Exception): """Generic exception for the 'elastic-transport' package. For the 'errors' attribute, errors are ordered from most recently raised (index=0) to least recently raised (index=N) If an HTTP status code is available with the error it will be stored under 'status'. If HTTP headers are available they are stored under 'headers'. """ def __init__(self, message: Any, errors: Tuple[Exception, ...] = ()): super().__init__(message) self.errors = tuple(errors) self.message = message def __repr__(self) -> str: parts = [repr(self.message)] if self.errors: parts.append(f"errors={self.errors!r}") return "{}({})".format(self.__class__.__name__, ", ".join(parts)) def __str__(self) -> str: return str(self.message) class SniffingError(TransportError): """Error that occurs during the sniffing of nodes""" class SerializationError(TransportError): """Error that occurred during the serialization or deserialization of an HTTP message body """ class ConnectionError(TransportError): """Error raised by the HTTP connection""" def __str__(self) -> str: if self.errors: return f"Connection error caused by: {self.errors[0].__class__.__name__}({self.errors[0]})" return "Connection error" class TlsError(ConnectionError): """Error raised by during the TLS handshake""" def __str__(self) -> str: if self.errors: return f"TLS error caused by: {self.errors[0].__class__.__name__}({self.errors[0]})" return "TLS error" class ConnectionTimeout(TransportError): """Connection timed out during an operation""" def __str__(self) -> str: if self.errors: return f"Connection timeout caused by: {self.errors[0].__class__.__name__}({self.errors[0]})" return "Connection timed out" class ApiError(Exception): """Base-class for clients that raise errors due to a response such as '404 Not Found'""" def __init__( self, message: str, meta: ApiResponseMeta, body: Any, errors: Tuple[Exception, ...] = (), ): super().__init__(message) self.message = message self.errors = errors self.meta = meta self.body = body def __repr__(self) -> str: parts = [repr(self.message)] if self.meta: parts.append(f"meta={self.meta!r}") if self.errors: parts.append(f"errors={self.errors!r}") if self.body is not None: parts.append(f"body={self.body!r}") return "{}({})".format(self.__class__.__name__, ", ".join(parts)) def __str__(self) -> str: return f"[{self.meta.status}] {self.message}" elastic-transport-python-8.17.1/elastic_transport/_models.py000066400000000000000000000314111476450415400243530ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import dataclasses import enum import re import ssl from dataclasses import dataclass, field from typing import ( TYPE_CHECKING, Any, Collection, Dict, Iterator, KeysView, Mapping, MutableMapping, Optional, Tuple, TypeVar, Union, ValuesView, ) if TYPE_CHECKING: from typing import Final class DefaultType(enum.Enum): """ Sentinel used as a default value when ``None`` has special meaning like timeouts. The only comparisons that are supported for this type are ``is``. """ value = 0 def __repr__(self) -> str: return "" def __str__(self) -> str: return "" DEFAULT: "Final[DefaultType]" = DefaultType.value T = TypeVar("T") _TYPE_SSL_VERSION = Union[int, ssl.TLSVersion] class HttpHeaders(MutableMapping[str, str]): """HTTP headers Behaves like a Python dictionary. Can be used like this:: headers = HttpHeaders() headers["foo"] = "bar" headers["foo"] = "baz" print(headers["foo"]) # prints "baz" """ __slots__ = ("_internal", "_frozen") def __init__( self, initial: Optional[Union[Mapping[str, str], Collection[Tuple[str, str]]]] = None, ) -> None: self._internal = {} self._frozen = False if initial: for key, val in dict(initial).items(): self._internal[self._normalize_key(key)] = (key, val) def __setitem__(self, key: str, value: str) -> None: if self._frozen: raise ValueError("Can't modify headers that have been frozen") self._internal[self._normalize_key(key)] = (key, value) def __getitem__(self, item: str) -> str: return self._internal[self._normalize_key(item)][1] def __delitem__(self, key: str) -> None: if self._frozen: raise ValueError("Can't modify headers that have been frozen") del self._internal[self._normalize_key(key)] def __eq__(self, other: object) -> bool: if not isinstance(other, Mapping): return NotImplemented if not isinstance(other, HttpHeaders): other = HttpHeaders(other) return {k: v for k, (_, v) in self._internal.items()} == { k: v for k, (_, v) in other._internal.items() } def __ne__(self, other: object) -> bool: if not isinstance(other, Mapping): return NotImplemented return not self == other def __iter__(self) -> Iterator[str]: return iter(self.keys()) def __len__(self) -> int: return len(self._internal) def __bool__(self) -> bool: return bool(self._internal) def __contains__(self, item: object) -> bool: return isinstance(item, str) and self._normalize_key(item) in self._internal def __repr__(self) -> str: return repr(self._dict_hide_auth()) def __str__(self) -> str: return str(self._dict_hide_auth()) def __hash__(self) -> int: if not self._frozen: raise ValueError("Can't calculate the hash of headers that aren't frozen") return hash(tuple((k, v) for k, (_, v) in sorted(self._internal.items()))) def get(self, key: str, default: Optional[str] = None) -> Optional[str]: # type: ignore[override] return self._internal.get(self._normalize_key(key), (None, default))[1] def keys(self) -> KeysView[str]: return self._internal.keys() def values(self) -> ValuesView[str]: return {"": v for _, v in self._internal.values()}.values() def items(self) -> Collection[Tuple[str, str]]: # type: ignore[override] return [(key, val) for _, (key, val) in self._internal.items()] def freeze(self) -> "HttpHeaders": """Freezes the current set of headers so they can be used in hashes. Returns the same instance, doesn't make a copy. """ self._frozen = True return self @property def frozen(self) -> bool: return self._frozen def copy(self) -> "HttpHeaders": return HttpHeaders(self.items()) def _normalize_key(self, key: str) -> str: try: return key.lower() except AttributeError: return key def _dict_hide_auth(self) -> Dict[str, str]: def hide_auth(val: str) -> str: # Hides only the authentication value, not the method. match = re.match(r"^(ApiKey|Basic|Bearer) ", val) if match: return f"{match.group(1)} " return "" return { key: hide_auth(val) if key.lower() == "authorization" else val for key, val in self.items() } @dataclass class ApiResponseMeta: """Metadata that is returned from Transport.perform_request() :ivar int status: HTTP status code :ivar str http_version: HTTP version being used :ivar HttpHeaders headers: HTTP headers :ivar float duration: Number of seconds from start of request to start of response :ivar NodeConfig node: Node which handled the request :ivar typing.Optional[str] mimetype: Mimetype to be used by the serializer to decode the raw response bytes. """ status: int http_version: str headers: HttpHeaders duration: float node: "NodeConfig" @property def mimetype(self) -> Optional[str]: try: content_type = self.headers["content-type"] return content_type.partition(";")[0] or None except KeyError: return None def _empty_frozen_http_headers() -> HttpHeaders: """Used for the 'default_factory' of the 'NodeConfig.headers'""" return HttpHeaders().freeze() @dataclass(repr=True) class NodeConfig: """Configuration options available for every node.""" #: Protocol in use to connect to the node scheme: str #: IP address or hostname to connect to host: str #: IP port to connect to port: int #: Prefix to add to the path of every request path_prefix: str = "" #: Default HTTP headers to add to every request headers: Union[HttpHeaders, Mapping[str, str]] = field( default_factory=_empty_frozen_http_headers ) #: Number of concurrent connections that are #: able to be open at one time for this node. #: Having multiple connections per node allows #: for higher concurrency of requests. connections_per_node: int = 10 #: Number of seconds to wait before a request should timeout. request_timeout: Optional[float] = 10.0 #: Set to ``True`` to enable HTTP compression #: of request and response bodies via gzip. http_compress: Optional[bool] = False #: Set to ``True`` to verify the node's TLS certificate against 'ca_certs' #: Setting to ``False`` will disable verifying the node's certificate. verify_certs: Optional[bool] = True #: Path to a CA bundle or directory containing bundles. By default #: If the ``certifi`` package is installed and ``verify_certs`` is #: set to ``True`` this value will be set to ``certifi.where()``. ca_certs: Optional[str] = None #: Path to a client certificate for TLS client authentication. client_cert: Optional[str] = None #: Path to a client private key for TLS client authentication. client_key: Optional[str] = None #: Hostname or IP address to verify on the node's certificate. #: This is useful if the certificate contains a different value #: than the one supplied in ``host``. An example of this situation #: is connecting to an IP address instead of a hostname. #: Set to ``False`` to disable certificate hostname verification. ssl_assert_hostname: Optional[str] = None #: SHA-256 fingerprint of the node's certificate. If this value is #: given then root-of-trust verification isn't done and only the #: node's certificate fingerprint is verified. #: #: On CPython 3.10+ this also verifies if any certificate in the #: chain including the Root CA matches this fingerprint. However #: because this requires using private APIs support for this is #: **experimental**. ssl_assert_fingerprint: Optional[str] = None #: Minimum TLS version to use to connect to the node. Can be either #: :class:`ssl.TLSVersion` or one of the deprecated #: ``ssl.PROTOCOL_TLSvX`` instances. ssl_version: Optional[_TYPE_SSL_VERSION] = None #: Pre-configured :class:`ssl.SSLContext` object. If this value #: is given then no other TLS options (besides ``ssl_assert_fingerprint``) #: can be set on the :class:`elastic_transport.NodeConfig`. ssl_context: Optional[ssl.SSLContext] = field(default=None, hash=False) #: Set to ``False`` to disable the :class:`elastic_transport.SecurityWarning` #: issued when using ``verify_certs=False``. ssl_show_warn: bool = True #: Extras that can be set to anything, typically used #: for annotating this node with additional information for #: future decisions like sniffing, instance roles, etc. #: Third-party keys should start with an underscore and prefix. _extras: Dict[str, Any] = field(default_factory=dict, hash=False) def replace(self, **kwargs: Any) -> "NodeConfig": if not kwargs: return self return dataclasses.replace(self, **kwargs) def __post_init__(self) -> None: if not isinstance(self.headers, HttpHeaders) or not self.headers.frozen: self.headers = HttpHeaders(self.headers).freeze() if self.scheme != self.scheme.lower(): raise ValueError("'scheme' must be lowercase") if "[" in self.host or "]" in self.host: raise ValueError("'host' must not have square braces") if self.port < 0: raise ValueError("'port' must be a positive integer") if self.connections_per_node <= 0: raise ValueError("'connections_per_node' must be a positive integer") if self.path_prefix: self.path_prefix = ( ("/" + self.path_prefix.strip("/")) if self.path_prefix else "" ) tls_options = [ "ca_certs", "client_cert", "client_key", "ssl_assert_hostname", "ssl_assert_fingerprint", "ssl_context", ] # Disallow setting TLS options on non-HTTPS connections. if self.scheme != "https": if any(getattr(self, attr) is not None for attr in tls_options): raise ValueError("TLS options require scheme to be 'https'") elif self.scheme == "https": # It's not valid to set 'ssl_context' and any other # TLS option, the SSLContext object must be configured # the way the user wants already. def tls_option_filter(attr: object) -> bool: return ( isinstance(attr, str) and attr not in ("ssl_context", "ssl_assert_fingerprint") and getattr(self, attr) is not None ) if self.ssl_context is not None and any( filter( tls_option_filter, tls_options, ) ): raise ValueError( "The 'ssl_context' option can't be combined with other TLS options" ) def __eq__(self, other: object) -> bool: if not isinstance(other, NodeConfig): return NotImplemented return ( self.scheme == other.scheme and self.host == other.host and self.port == other.port and self.path_prefix == other.path_prefix ) def __ne__(self, other: object) -> bool: if not isinstance(other, NodeConfig): return NotImplemented return not self == other def __hash__(self) -> int: return hash( ( self.scheme, self.host, self.port, self.path_prefix, ) ) @dataclass() class SniffOptions: """Options which are passed to Transport.sniff_callback""" is_initial_sniff: bool sniff_timeout: Optional[float] elastic-transport-python-8.17.1/elastic_transport/_node/000077500000000000000000000000001476450415400234435ustar00rootroot00000000000000elastic-transport-python-8.17.1/elastic_transport/_node/__init__.py000066400000000000000000000023021476450415400255510ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. from ._base import BaseNode, NodeApiResponse from ._base_async import BaseAsyncNode from ._http_aiohttp import AiohttpHttpNode from ._http_httpx import HttpxAsyncHttpNode from ._http_requests import RequestsHttpNode from ._http_urllib3 import Urllib3HttpNode __all__ = [ "AiohttpHttpNode", "BaseNode", "BaseAsyncNode", "NodeApiResponse", "RequestsHttpNode", "Urllib3HttpNode", "HttpxAsyncHttpNode", ] elastic-transport-python-8.17.1/elastic_transport/_node/_base.py000066400000000000000000000262101476450415400250670ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import asyncio import logging import os import ssl from typing import Any, ClassVar, List, NamedTuple, Optional, Tuple, Union from .._models import ApiResponseMeta, HttpHeaders, NodeConfig from .._utils import is_ipaddress from .._version import __version__ from ..client_utils import DEFAULT, DefaultType _logger = logging.getLogger("elastic_transport.node") _logger.propagate = False # This logger is very verbose so disable propogation. DEFAULT_CA_CERTS: Optional[str] = None DEFAULT_USER_AGENT = f"elastic-transport-python/{__version__}" RERAISE_EXCEPTIONS = (RecursionError, asyncio.CancelledError) BUILTIN_EXCEPTIONS = ( ValueError, KeyError, NameError, AttributeError, LookupError, AssertionError, IndexError, MemoryError, RuntimeError, SystemError, TypeError, ) HTTP_STATUS_REASONS = { 200: "OK", 201: "Created", 202: "Accepted", 204: "No Content", 205: "Reset Content", 206: "Partial Content", 400: "Bad Request", 401: "Unauthorized", 402: "Payment Required", 403: "Forbidden", 404: "Not Found", 405: "Method Not Allowed", 406: "Not Acceptable", 407: "Proxy Authentication Required", 408: "Request Timeout", 409: "Conflict", 410: "Gone", 411: "Length Required", 412: "Precondition Failed", 413: "Content Too Large", 414: "URI Too Long", 415: "Unsupported Media Type", 429: "Too Many Requests", 500: "Internal Server Error", 501: "Not Implemented", 502: "Bad Gateway", 503: "Service Unavailable", 504: "Gateway Timeout", } try: import certifi DEFAULT_CA_CERTS = certifi.where() except ImportError: # pragma: nocover pass class NodeApiResponse(NamedTuple): meta: ApiResponseMeta body: bytes class BaseNode: """ Class responsible for maintaining a connection to a node. It holds persistent node pool to it and it's main interface (``perform_request``) is thread-safe. :arg config: :class:`~elastic_transport.NodeConfig` instance """ _CLIENT_META_HTTP_CLIENT: ClassVar[Tuple[str, str]] def __init__(self, config: NodeConfig): self._config = config self._headers: HttpHeaders = self.config.headers.copy() # type: ignore[attr-defined] self.headers.setdefault("connection", "keep-alive") self.headers.setdefault("user-agent", DEFAULT_USER_AGENT) self._http_compress = bool(config.http_compress or False) if config.http_compress: self.headers["accept-encoding"] = "gzip" self._scheme = config.scheme self._host = config.host self._port = config.port self._path_prefix = ( ("/" + config.path_prefix.strip("/")) if config.path_prefix else "" ) @property def config(self) -> NodeConfig: return self._config @property def headers(self) -> HttpHeaders: return self._headers @property def scheme(self) -> str: return self._scheme @property def host(self) -> str: return self._host @property def port(self) -> int: return self._port @property def path_prefix(self) -> str: return self._path_prefix def __repr__(self) -> str: return f"<{self.__class__.__name__}({self.base_url})>" def __lt__(self, other: object) -> bool: if not isinstance(other, BaseNode): return NotImplemented return id(self) < id(other) def __eq__(self, other: object) -> bool: if not isinstance(other, BaseNode): return NotImplemented return self.__hash__() == other.__hash__() def __ne__(self, other: object) -> bool: if not isinstance(other, BaseNode): return NotImplemented return not self == other def __hash__(self) -> int: return hash((str(type(self).__name__), self.config)) @property def base_url(self) -> str: return "".join( [ self.scheme, "://", # IPv6 must be wrapped by [...] "[%s]" % self.host if ":" in self.host else self.host, ":%s" % self.port if self.port is not None else "", self.path_prefix, ] ) def perform_request( self, method: str, target: str, body: Optional[bytes] = None, headers: Optional[HttpHeaders] = None, request_timeout: Union[DefaultType, Optional[float]] = DEFAULT, ) -> NodeApiResponse: # pragma: nocover """Constructs and sends an HTTP request and parses the HTTP response. :param method: HTTP method :param target: HTTP request target, typically path+query :param body: Optional HTTP request body encoded as bytes :param headers: Optional HTTP headers to send in addition to the headers already configured. :param request_timeout: Amount of time to wait for the first response bytes to arrive before raising a :class:`elastic_transport.ConnectionTimeout` error. :raises: :class:`elastic_transport.ConnectionError`, :class:`elastic_transport.ConnectionTimeout`, :class:`elastic_transport.TlsError` :rtype: Tuple[ApiResponseMeta, bytes] :returns: Metadata about the request+response and the raw decompressed bytes from the HTTP response body. """ raise NotImplementedError() def close(self) -> None: # pragma: nocover pass def _log_request( self, method: str, target: str, headers: Optional[HttpHeaders], body: Optional[bytes], meta: Optional[ApiResponseMeta] = None, response: Optional[bytes] = None, exception: Optional[Exception] = None, ) -> None: if _logger.hasHandlers(): http_version = meta.http_version if meta else "?.?" lines = ["> %s %s HTTP/%s"] log_args: List[Any] = [method, target, http_version] if headers: for header, value in sorted(headers._dict_hide_auth().items()): lines.append(f"> {header.title()}: {value}") if body is not None: try: body_encoded = body.decode("utf-8", "surrogatepass") except UnicodeError: body_encoded = repr(body) log_args.append(body_encoded) lines.append("> %s") if meta is not None: reason = HTTP_STATUS_REASONS.get(meta.status, None) if reason: lines.append("< HTTP/%s %d %s") log_args.extend((http_version, meta.status, reason)) else: lines.append("< HTTP/%s %d") log_args.extend((http_version, meta.status)) if meta.headers: for header, value in sorted(meta.headers.items()): lines.append(f"< {header.title()}: {value}") if response: try: response_decoded = response.decode("utf-8", "surrogatepass") except UnicodeError: response_decoded = repr(response) log_args.append(response_decoded) lines.append("< %s") if exception is not None: _logger.debug("\n".join(lines), *log_args, exc_info=exception) else: _logger.debug("\n".join(lines), *log_args) _HAS_TLS_VERSION = hasattr(ssl, "TLSVersion") _SSL_PROTOCOL_VERSION_ATTRS = ("TLSv1", "TLSv1_1", "TLSv1_2") _SSL_PROTOCOL_VERSION_DEFAULT = getattr(ssl, "OP_NO_SSLv2", 0) | getattr( ssl, "OP_NO_SSLv3", 0 ) _SSL_PROTOCOL_VERSION_TO_OPTIONS = {} _SSL_PROTOCOL_VERSION_TO_TLS_VERSION = {} for i, _protocol_attr in enumerate(_SSL_PROTOCOL_VERSION_ATTRS): try: _protocol_value = getattr(ssl, f"PROTOCOL_{_protocol_attr}") except AttributeError: continue if _HAS_TLS_VERSION: _tls_version_value = getattr(ssl.TLSVersion, _protocol_attr) _SSL_PROTOCOL_VERSION_TO_TLS_VERSION[_protocol_value] = _tls_version_value _SSL_PROTOCOL_VERSION_TO_TLS_VERSION[_tls_version_value] = _tls_version_value # Because we're setting a minimum version we binary OR all the options together. _SSL_PROTOCOL_VERSION_TO_OPTIONS[_protocol_value] = ( _SSL_PROTOCOL_VERSION_DEFAULT | sum( getattr(ssl, f"OP_NO_{_attr}", 0) for _attr in _SSL_PROTOCOL_VERSION_ATTRS[:i] ) ) # TLSv1.3 is unique, doesn't have a PROTOCOL_TLSvX counterpart. So we have to set it manually. if _HAS_TLS_VERSION: try: _SSL_PROTOCOL_VERSION_TO_TLS_VERSION[ssl.TLSVersion.TLSv1_3] = ( ssl.TLSVersion.TLSv1_3 ) except AttributeError: # pragma: nocover pass def ssl_context_from_node_config(node_config: NodeConfig) -> ssl.SSLContext: if node_config.ssl_context: ctx = node_config.ssl_context else: ctx = ssl.create_default_context() # Enable/disable certificate verification in these orders # to avoid 'ValueErrors' from SSLContext. We only do this # step if the user doesn't pass a preconfigured SSLContext. if node_config.verify_certs: ctx.verify_mode = ssl.CERT_REQUIRED ctx.check_hostname = not is_ipaddress(node_config.host) else: ctx.check_hostname = False ctx.verify_mode = ssl.CERT_NONE # Enable logging of TLS session keys for use with Wireshark. if hasattr(ctx, "keylog_filename"): sslkeylogfile = os.environ.get("SSLKEYLOGFILE", "") if sslkeylogfile: ctx.keylog_filename = sslkeylogfile # Apply the 'ssl_version' if given, otherwise default to TLSv1.2+ ssl_version = node_config.ssl_version if ssl_version is None: if _HAS_TLS_VERSION: ssl_version = ssl.TLSVersion.TLSv1_2 else: ssl_version = ssl.PROTOCOL_TLSv1_2 try: if _HAS_TLS_VERSION: ctx.minimum_version = _SSL_PROTOCOL_VERSION_TO_TLS_VERSION[ssl_version] else: ctx.options |= _SSL_PROTOCOL_VERSION_TO_OPTIONS[ssl_version] except KeyError: raise ValueError( f"Unsupported value for 'ssl_version': {ssl_version!r}. Must be " "either 'ssl.PROTOCOL_TLSvX' or 'ssl.TLSVersion.TLSvX'" ) from None return ctx elastic-transport-python-8.17.1/elastic_transport/_node/_base_async.py000066400000000000000000000027461476450415400262740ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. from typing import Optional, Union from .._models import HttpHeaders from ..client_utils import DEFAULT, DefaultType from ._base import BaseNode, NodeApiResponse class BaseAsyncNode(BaseNode): """Base class for Async HTTP node implementations""" async def perform_request( # type: ignore[override] self, method: str, target: str, body: Optional[bytes] = None, headers: Optional[HttpHeaders] = None, request_timeout: Union[DefaultType, Optional[float]] = DEFAULT, ) -> NodeApiResponse: raise NotImplementedError() # pragma: nocover async def close(self) -> None: # type: ignore[override] raise NotImplementedError() # pragma: nocover elastic-transport-python-8.17.1/elastic_transport/_node/_http_aiohttp.py000066400000000000000000000250071476450415400266670ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import asyncio import base64 import functools import gzip import os import re import ssl import sys import warnings from typing import Optional, TypedDict, Union from .._compat import warn_stacklevel from .._exceptions import ConnectionError, ConnectionTimeout, SecurityWarning, TlsError from .._models import ApiResponseMeta, HttpHeaders, NodeConfig from ..client_utils import DEFAULT, DefaultType, client_meta_version from ._base import ( BUILTIN_EXCEPTIONS, DEFAULT_CA_CERTS, RERAISE_EXCEPTIONS, NodeApiResponse, ssl_context_from_node_config, ) from ._base_async import BaseAsyncNode try: import aiohttp import aiohttp.client_exceptions as aiohttp_exceptions _AIOHTTP_AVAILABLE = True _AIOHTTP_META_VERSION = client_meta_version(aiohttp.__version__) _version_parts = [] for _version_part in aiohttp.__version__.split(".")[:3]: try: _version_parts.append(int(re.search(r"^([0-9]+)", _version_part).group(1))) # type: ignore[union-attr] except (AttributeError, ValueError): break _AIOHTTP_SEMVER_VERSION = tuple(_version_parts) # See aio-libs/aiohttp#1769 and #5012 _AIOHTTP_FIXED_HEAD_BUG = _AIOHTTP_SEMVER_VERSION >= (3, 7, 0) class RequestKwarg(TypedDict, total=False): ssl: aiohttp.Fingerprint except ImportError: # pragma: nocover _AIOHTTP_AVAILABLE = False _AIOHTTP_META_VERSION = "" _AIOHTTP_FIXED_HEAD_BUG = False # Avoid aiohttp enabled_cleanup_closed warning: https://github.com/aio-libs/aiohttp/pull/9726 _NEEDS_CLEANUP_CLOSED_313 = (3, 13, 0) <= sys.version_info < (3, 13, 1) _NEEDS_CLEANUP_CLOSED = _NEEDS_CLEANUP_CLOSED_313 or sys.version_info < (3, 12, 7) class AiohttpHttpNode(BaseAsyncNode): """Default asynchronous node class using the ``aiohttp`` library via HTTP""" _CLIENT_META_HTTP_CLIENT = ("ai", _AIOHTTP_META_VERSION) def __init__(self, config: NodeConfig): if not _AIOHTTP_AVAILABLE: # pragma: nocover raise ValueError("You must have 'aiohttp' installed to use AiohttpHttpNode") super().__init__(config) self._ssl_assert_fingerprint = config.ssl_assert_fingerprint ssl_context: Optional[ssl.SSLContext] = None if config.scheme == "https": if config.ssl_context is not None: ssl_context = ssl_context_from_node_config(config) else: ssl_context = ssl_context_from_node_config(config) ca_certs = ( DEFAULT_CA_CERTS if config.ca_certs is None else config.ca_certs ) if config.verify_certs: if not ca_certs: raise ValueError( "Root certificates are missing for certificate " "validation. Either pass them in using the ca_certs parameter or " "install certifi to use it automatically." ) else: if config.ssl_show_warn: warnings.warn( f"Connecting to {self.base_url!r} using TLS with verify_certs=False is insecure", stacklevel=warn_stacklevel(), category=SecurityWarning, ) if ca_certs is not None: if os.path.isfile(ca_certs): ssl_context.load_verify_locations(cafile=ca_certs) elif os.path.isdir(ca_certs): ssl_context.load_verify_locations(capath=ca_certs) else: raise ValueError("ca_certs parameter is not a path") # Use client_cert and client_key variables for SSL certificate configuration. if config.client_cert and not os.path.isfile(config.client_cert): raise ValueError("client_cert is not a path to a file") if config.client_key and not os.path.isfile(config.client_key): raise ValueError("client_key is not a path to a file") if config.client_cert and config.client_key: ssl_context.load_cert_chain(config.client_cert, config.client_key) elif config.client_cert: ssl_context.load_cert_chain(config.client_cert) self._loop: asyncio.AbstractEventLoop = None # type: ignore[assignment] self.session: Optional[aiohttp.ClientSession] = None # Parameters for creating an aiohttp.ClientSession later. self._connections_per_node = config.connections_per_node self._ssl_context = ssl_context async def perform_request( # type: ignore[override] self, method: str, target: str, body: Optional[bytes] = None, headers: Optional[HttpHeaders] = None, request_timeout: Union[DefaultType, Optional[float]] = DEFAULT, ) -> NodeApiResponse: global _AIOHTTP_FIXED_HEAD_BUG if self.session is None: self._create_aiohttp_session() assert self.session is not None url = self.base_url + target is_head = False # There is a bug in aiohttp<3.7 that disables the re-use # of the connection in the pool when method=HEAD. # See: aio-libs/aiohttp#1769 if method == "HEAD" and not _AIOHTTP_FIXED_HEAD_BUG: method = "GET" is_head = True # total=0 means no timeout for aiohttp resolved_timeout: Optional[float] = ( self.config.request_timeout if request_timeout is DEFAULT else request_timeout ) aiohttp_timeout = aiohttp.ClientTimeout( total=resolved_timeout if resolved_timeout is not None else 0 ) request_headers = self._headers.copy() if headers: request_headers.update(headers) body_to_send: Optional[bytes] if body: if self._http_compress: body_to_send = gzip.compress(body) request_headers["content-encoding"] = "gzip" else: body_to_send = body else: body_to_send = None kwargs: RequestKwarg = {} if self._ssl_assert_fingerprint: kwargs["ssl"] = aiohttp_fingerprint(self._ssl_assert_fingerprint) try: start = self._loop.time() async with self.session.request( method, url, data=body_to_send, headers=request_headers, timeout=aiohttp_timeout, **kwargs, ) as response: if is_head: # We actually called 'GET' so throw away the data. await response.release() raw_data = b"" else: raw_data = await response.read() duration = self._loop.time() - start # We want to reraise a cancellation or recursion error. except RERAISE_EXCEPTIONS: raise except Exception as e: err: Exception if isinstance( e, (asyncio.TimeoutError, aiohttp_exceptions.ServerTimeoutError) ): err = ConnectionTimeout( "Connection timed out during request", errors=(e,) ) elif isinstance(e, (ssl.SSLError, aiohttp_exceptions.ClientSSLError)): err = TlsError(str(e), errors=(e,)) elif isinstance(e, BUILTIN_EXCEPTIONS): raise else: err = ConnectionError(str(e), errors=(e,)) self._log_request( method="HEAD" if is_head else method, target=target, headers=request_headers, body=body, exception=err, ) raise err from None meta = ApiResponseMeta( node=self.config, duration=duration, http_version="1.1", status=response.status, headers=HttpHeaders(response.headers), ) self._log_request( method="HEAD" if is_head else method, target=target, headers=request_headers, body=body, meta=meta, response=raw_data, ) return NodeApiResponse( meta, raw_data, ) async def close(self) -> None: # type: ignore[override] if self.session: await self.session.close() self.session = None def _create_aiohttp_session(self) -> None: """Creates an aiohttp.ClientSession(). This is delayed until the first call to perform_request() so that AsyncTransport has a chance to set AiohttpHttpNode.loop """ if self._loop is None: self._loop = asyncio.get_running_loop() self.session = aiohttp.ClientSession( headers=self.headers, skip_auto_headers=("accept", "accept-encoding", "user-agent"), auto_decompress=True, loop=self._loop, cookie_jar=aiohttp.DummyCookieJar(), connector=aiohttp.TCPConnector( limit_per_host=self._connections_per_node, use_dns_cache=True, enable_cleanup_closed=_NEEDS_CLEANUP_CLOSED, ssl=self._ssl_context or False, ), ) @functools.lru_cache(maxsize=64, typed=True) def aiohttp_fingerprint(ssl_assert_fingerprint: str) -> "aiohttp.Fingerprint": """Changes 'ssl_assert_fingerprint' into a configured 'aiohttp.Fingerprint' instance. Uses a cache to prevent creating tons of objects needlessly. """ return aiohttp.Fingerprint( base64.b16decode(ssl_assert_fingerprint.replace(":", ""), casefold=True) ) elastic-transport-python-8.17.1/elastic_transport/_node/_http_httpx.py000066400000000000000000000170051476450415400263650ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import gzip import os.path import ssl import time import warnings from typing import Literal, Optional, Union from .._compat import warn_stacklevel from .._exceptions import ConnectionError, ConnectionTimeout, SecurityWarning, TlsError from .._models import ApiResponseMeta, HttpHeaders, NodeConfig from ..client_utils import DEFAULT, DefaultType, client_meta_version from ._base import ( BUILTIN_EXCEPTIONS, DEFAULT_CA_CERTS, RERAISE_EXCEPTIONS, NodeApiResponse, ssl_context_from_node_config, ) from ._base_async import BaseAsyncNode try: import httpx _HTTPX_AVAILABLE = True _HTTPX_META_VERSION = client_meta_version(httpx.__version__) except ImportError: _HTTPX_AVAILABLE = False _HTTPX_META_VERSION = "" class HttpxAsyncHttpNode(BaseAsyncNode): _CLIENT_META_HTTP_CLIENT = ("hx", _HTTPX_META_VERSION) def __init__(self, config: NodeConfig): if not _HTTPX_AVAILABLE: # pragma: nocover raise ValueError("You must have 'httpx' installed to use HttpxNode") super().__init__(config) if config.ssl_assert_fingerprint: raise ValueError( "httpx does not support certificate pinning. https://github.com/encode/httpx/issues/761" ) ssl_context: Union[ssl.SSLContext, Literal[False]] = False if config.scheme == "https": if config.ssl_context is not None: ssl_context = ssl_context_from_node_config(config) else: ssl_context = ssl_context_from_node_config(config) ca_certs = ( DEFAULT_CA_CERTS if config.ca_certs is None else config.ca_certs ) if config.verify_certs: if not ca_certs: raise ValueError( "Root certificates are missing for certificate " "validation. Either pass them in using the ca_certs parameter or " "install certifi to use it automatically." ) else: if config.ssl_show_warn: warnings.warn( f"Connecting to {self.base_url!r} using TLS with verify_certs=False is insecure", stacklevel=warn_stacklevel(), category=SecurityWarning, ) if ca_certs is not None: if os.path.isfile(ca_certs): ssl_context.load_verify_locations(cafile=ca_certs) elif os.path.isdir(ca_certs): ssl_context.load_verify_locations(capath=ca_certs) else: raise ValueError("ca_certs parameter is not a path") # Use client_cert and client_key variables for SSL certificate configuration. if config.client_cert and not os.path.isfile(config.client_cert): raise ValueError("client_cert is not a path to a file") if config.client_key and not os.path.isfile(config.client_key): raise ValueError("client_key is not a path to a file") if config.client_cert and config.client_key: ssl_context.load_cert_chain(config.client_cert, config.client_key) elif config.client_cert: ssl_context.load_cert_chain(config.client_cert) self.client = httpx.AsyncClient( base_url=f"{config.scheme}://{config.host}:{config.port}", limits=httpx.Limits(max_connections=config.connections_per_node), verify=ssl_context or False, timeout=config.request_timeout, ) async def perform_request( # type: ignore[override] self, method: str, target: str, body: Optional[bytes] = None, headers: Optional[HttpHeaders] = None, request_timeout: Union[DefaultType, Optional[float]] = DEFAULT, ) -> NodeApiResponse: resolved_headers = self._headers.copy() if headers: resolved_headers.update(headers) if body: if self._http_compress: resolved_body = gzip.compress(body) resolved_headers["content-encoding"] = "gzip" else: resolved_body = body else: resolved_body = None try: start = time.perf_counter() if request_timeout is DEFAULT: resp = await self.client.request( method, target, content=resolved_body, headers=dict(resolved_headers), ) else: resp = await self.client.request( method, target, content=resolved_body, headers=dict(resolved_headers), timeout=request_timeout, ) response_body = resp.read() duration = time.perf_counter() - start except RERAISE_EXCEPTIONS + BUILTIN_EXCEPTIONS: raise except Exception as e: err: Exception if isinstance(e, (TimeoutError, httpx.TimeoutException)): err = ConnectionTimeout( "Connection timed out during request", errors=(e,) ) elif isinstance(e, ssl.SSLError): err = TlsError(str(e), errors=(e,)) # Detect SSL errors for httpx v0.28.0+ # Needed until https://github.com/encode/httpx/issues/3350 is fixed elif isinstance(e, httpx.ConnectError) and e.__cause__: context = e.__cause__.__context__ if isinstance(context, ssl.SSLError): err = TlsError(str(context), errors=(e,)) else: err = ConnectionError(str(e), errors=(e,)) else: err = ConnectionError(str(e), errors=(e,)) self._log_request( method=method, target=target, headers=resolved_headers, body=body, exception=err, ) raise err from None meta = ApiResponseMeta( resp.status_code, resp.http_version, HttpHeaders(resp.headers), duration, self.config, ) self._log_request( method=method, target=target, headers=resolved_headers, body=body, meta=meta, response=response_body, ) return NodeApiResponse(meta, response_body) async def close(self) -> None: # type: ignore[override] await self.client.aclose() elastic-transport-python-8.17.1/elastic_transport/_node/_http_requests.py000066400000000000000000000245371476450415400271010ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import gzip import ssl import time import warnings from typing import Any, Optional, Union import urllib3 from .._compat import warn_stacklevel from .._exceptions import ConnectionError, ConnectionTimeout, SecurityWarning, TlsError from .._models import ApiResponseMeta, HttpHeaders, NodeConfig from ..client_utils import DEFAULT, DefaultType, client_meta_version from ._base import ( BUILTIN_EXCEPTIONS, RERAISE_EXCEPTIONS, BaseNode, NodeApiResponse, ssl_context_from_node_config, ) try: import requests from requests.adapters import HTTPAdapter from requests.auth import AuthBase _REQUESTS_AVAILABLE = True _REQUESTS_META_VERSION = client_meta_version(requests.__version__) # Use our custom HTTPSConnectionPool for chain cert fingerprint support. try: from ._urllib3_chain_certs import HTTPSConnectionPool except (ImportError, AttributeError): HTTPSConnectionPool = urllib3.HTTPSConnectionPool # type: ignore[assignment,misc] class _ElasticHTTPAdapter(HTTPAdapter): def __init__(self, node_config: NodeConfig, **kwargs: Any) -> None: self._node_config = node_config super().__init__(**kwargs) def init_poolmanager( self, connections: Any, maxsize: int, block: bool = False, **pool_kwargs: Any, ) -> None: if self._node_config.scheme == "https": ssl_context = ssl_context_from_node_config(self._node_config) pool_kwargs.setdefault("ssl_context", ssl_context) # Fingerprint verification doesn't require CA certificates being loaded. # We also want to disable other verification methods as we only care # about the fingerprint of the certificates, not whether they form # a verified chain to a trust anchor. if self._node_config.ssl_assert_fingerprint: # Manually disable these in the right order on the SSLContext # so urllib3 won't think we want conflicting things. ssl_context.check_hostname = False ssl_context.verify_mode = ssl.CERT_NONE pool_kwargs["assert_fingerprint"] = ( self._node_config.ssl_assert_fingerprint ) pool_kwargs["cert_reqs"] = "CERT_NONE" pool_kwargs["assert_hostname"] = False super().init_poolmanager(connections, maxsize, block=block, **pool_kwargs) # type: ignore [no-untyped-call] self.poolmanager.pool_classes_by_scheme["https"] = HTTPSConnectionPool except ImportError: # pragma: nocover _REQUESTS_AVAILABLE = False _REQUESTS_META_VERSION = "" class RequestsHttpNode(BaseNode): """Synchronous node using the ``requests`` library communicating via HTTP. Supports setting :attr:`requests.Session.auth` via the :attr:`elastic_transport.NodeConfig._extras` using the ``requests.session.auth`` key. """ _CLIENT_META_HTTP_CLIENT = ("rq", _REQUESTS_META_VERSION) def __init__(self, config: NodeConfig): if not _REQUESTS_AVAILABLE: # pragma: nocover raise ValueError( "You must have 'requests' installed to use RequestsHttpNode" ) super().__init__(config) # Initialize Session so .headers works before calling super().__init__(). self.session = requests.Session() self.session.headers.clear() # Empty out all the default session headers if config.scheme == "https": # If we're using ssl_assert_fingerprint we don't want # to verify certificates the typical way. Instead we # rely on the custom ElasticHTTPAdapter and urllib3. if config.ssl_assert_fingerprint: self.session.verify = False # Otherwise we go the traditional route of verifying certs. else: if config.ca_certs: if not config.verify_certs: raise ValueError( "You cannot use 'ca_certs' when 'verify_certs=False'" ) self.session.verify = config.ca_certs else: self.session.verify = config.verify_certs if not config.ssl_show_warn: urllib3.disable_warnings() if ( config.scheme == "https" and not config.verify_certs and config.ssl_show_warn ): warnings.warn( f"Connecting to {self.base_url!r} using TLS with verify_certs=False is insecure", stacklevel=warn_stacklevel(), category=SecurityWarning, ) # Requests supports setting 'session.auth' via _extras['requests.session.auth'] = ... try: requests_session_auth: Optional[AuthBase] = config._extras.pop( "requests.session.auth", None ) except AttributeError: requests_session_auth = None if requests_session_auth is not None: self.session.auth = requests_session_auth # Client certificates if config.client_cert: if config.client_key: self.session.cert = (config.client_cert, config.client_key) else: self.session.cert = config.client_cert # Create and mount custom adapter for constraining number of connections adapter = _ElasticHTTPAdapter( node_config=config, pool_connections=config.connections_per_node, pool_maxsize=config.connections_per_node, pool_block=True, ) # Preload the HTTPConnectionPool so initialization issues # are raised here instead of in perform_request() if hasattr(adapter, "get_connection_with_tls_context"): request = requests.Request(method="GET", url=self.base_url) prepared_request = self.session.prepare_request(request) adapter.get_connection_with_tls_context( prepared_request, verify=self.session.verify ) else: # elastic-transport is not vulnerable to CVE-2024-35195 because it uses # requests.Session and an SSLContext without using the verify parameter. # We should remove this branch when requiring requests 2.32 or later. adapter.get_connection(self.base_url) self.session.mount(prefix=f"{self.scheme}://", adapter=adapter) def perform_request( self, method: str, target: str, body: Optional[bytes] = None, headers: Optional[HttpHeaders] = None, request_timeout: Union[DefaultType, Optional[float]] = DEFAULT, ) -> NodeApiResponse: url = self.base_url + target headers = HttpHeaders(headers or ()) request_headers = self._headers.copy() if headers: request_headers.update(headers) body_to_send: Optional[bytes] if body: if self._http_compress: body_to_send = gzip.compress(body) request_headers["content-encoding"] = "gzip" else: body_to_send = body else: body_to_send = None start = time.time() request = requests.Request( method=method, headers=request_headers, url=url, data=body_to_send ) prepared_request = self.session.prepare_request(request) send_kwargs = { "timeout": ( request_timeout if request_timeout is not DEFAULT else self.config.request_timeout ) } send_kwargs.update( self.session.merge_environment_settings( # type: ignore[arg-type] prepared_request.url, {}, None, None, None ) ) try: response = self.session.send(prepared_request, **send_kwargs) # type: ignore[arg-type] data = response.content duration = time.time() - start response_headers = HttpHeaders(response.headers) except RERAISE_EXCEPTIONS: raise except Exception as e: err: Exception if isinstance(e, requests.Timeout): err = ConnectionTimeout( "Connection timed out during request", errors=(e,) ) elif isinstance(e, (ssl.SSLError, requests.exceptions.SSLError)): err = TlsError(str(e), errors=(e,)) elif isinstance(e, BUILTIN_EXCEPTIONS): raise else: err = ConnectionError(str(e), errors=(e,)) self._log_request( method=method, target=target, headers=request_headers, body=body, exception=err, ) raise err from None meta = ApiResponseMeta( node=self.config, duration=duration, http_version="1.1", status=response.status_code, headers=response_headers, ) self._log_request( method=method, target=target, headers=request_headers, body=body, meta=meta, response=data, ) return NodeApiResponse( meta, data, ) def close(self) -> None: """ Explicitly closes connections """ self.session.close() elastic-transport-python-8.17.1/elastic_transport/_node/_http_urllib3.py000066400000000000000000000176731476450415400266050ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import gzip import ssl import time import warnings from typing import Any, Dict, Optional, Union try: from importlib import metadata except ImportError: import importlib_metadata as metadata # type: ignore[no-redef] import urllib3 from urllib3.exceptions import ConnectTimeoutError, NewConnectionError, ReadTimeoutError from urllib3.util.retry import Retry from .._compat import warn_stacklevel from .._exceptions import ConnectionError, ConnectionTimeout, SecurityWarning, TlsError from .._models import ApiResponseMeta, HttpHeaders, NodeConfig from ..client_utils import DEFAULT, DefaultType, client_meta_version from ._base import ( BUILTIN_EXCEPTIONS, DEFAULT_CA_CERTS, RERAISE_EXCEPTIONS, BaseNode, NodeApiResponse, ssl_context_from_node_config, ) try: from ._urllib3_chain_certs import HTTPSConnectionPool except (ImportError, AttributeError): HTTPSConnectionPool = urllib3.HTTPSConnectionPool # type: ignore[assignment,misc] class Urllib3HttpNode(BaseNode): """Default synchronous node class using the ``urllib3`` library via HTTP""" _CLIENT_META_HTTP_CLIENT = ("ur", client_meta_version(metadata.version("urllib3"))) def __init__(self, config: NodeConfig): super().__init__(config) pool_class = urllib3.HTTPConnectionPool kw: Dict[str, Any] = {} if config.scheme == "https": pool_class = HTTPSConnectionPool ssl_context = ssl_context_from_node_config(config) kw["ssl_context"] = ssl_context if config.ssl_assert_hostname and config.ssl_assert_fingerprint: raise ValueError( "Can't specify both 'ssl_assert_hostname' and 'ssl_assert_fingerprint'" ) # Fingerprint verification doesn't require CA certificates being loaded. # We also want to disable other verification methods as we only care # about the fingerprint of the certificates, not whether they form # a verified chain to a trust anchor. elif config.ssl_assert_fingerprint: # Manually disable these in the right order on the SSLContext # so urllib3 won't think we want conflicting things. ssl_context.check_hostname = False ssl_context.verify_mode = ssl.CERT_NONE kw.update( { "assert_fingerprint": config.ssl_assert_fingerprint, "assert_hostname": False, "cert_reqs": "CERT_NONE", } ) else: kw["assert_hostname"] = config.ssl_assert_hostname # Convert all sentinel values to their actual default # values if not using an SSLContext. ca_certs = ( DEFAULT_CA_CERTS if config.ca_certs is None else config.ca_certs ) if config.verify_certs: if not ca_certs: raise ValueError( "Root certificates are missing for certificate " "validation. Either pass them in using the ca_certs parameter or " "install certifi to use it automatically." ) kw.update( { "cert_reqs": "CERT_REQUIRED", "ca_certs": ca_certs, "cert_file": config.client_cert, "key_file": config.client_key, } ) else: kw["cert_reqs"] = "CERT_NONE" if config.ssl_show_warn: warnings.warn( f"Connecting to {self.base_url!r} using TLS with verify_certs=False is insecure", stacklevel=warn_stacklevel(), category=SecurityWarning, ) else: urllib3.disable_warnings() self.pool = pool_class( config.host, port=config.port, timeout=urllib3.Timeout(total=config.request_timeout), maxsize=config.connections_per_node, block=True, **kw, ) def perform_request( self, method: str, target: str, body: Optional[bytes] = None, headers: Optional[HttpHeaders] = None, request_timeout: Union[DefaultType, Optional[float]] = DEFAULT, ) -> NodeApiResponse: if self.path_prefix: target = f"{self.path_prefix}{target}" start = time.time() try: kw = {} if request_timeout is not DEFAULT: kw["timeout"] = request_timeout request_headers = self._headers.copy() if headers: request_headers.update(headers) body_to_send: Optional[bytes] if body: if self._http_compress: body_to_send = gzip.compress(body) request_headers["content-encoding"] = "gzip" else: body_to_send = body else: body_to_send = None response = self.pool.urlopen( method, target, body=body_to_send, retries=Retry(False), headers=request_headers, **kw, # type: ignore[arg-type] ) response_headers = HttpHeaders(response.headers) data = response.data duration = time.time() - start except RERAISE_EXCEPTIONS: raise except Exception as e: err: Exception if isinstance(e, NewConnectionError): err = ConnectionError(str(e), errors=(e,)) elif isinstance(e, (ConnectTimeoutError, ReadTimeoutError)): err = ConnectionTimeout( "Connection timed out during request", errors=(e,) ) elif isinstance(e, (ssl.SSLError, urllib3.exceptions.SSLError)): err = TlsError(str(e), errors=(e,)) elif isinstance(e, BUILTIN_EXCEPTIONS): raise else: err = ConnectionError(str(e), errors=(e,)) self._log_request( method=method, target=target, headers=request_headers, body=body, exception=err, ) raise err from e meta = ApiResponseMeta( node=self.config, duration=duration, http_version="1.1", status=response.status, headers=response_headers, ) self._log_request( method=method, target=target, headers=request_headers, body=body, meta=meta, response=data, ) return NodeApiResponse( meta, data, ) def close(self) -> None: """ Explicitly closes connection """ self.pool.close() elastic-transport-python-8.17.1/elastic_transport/_node/_urllib3_chain_certs.py000066400000000000000000000143621476450415400301000ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import hashlib import sys from binascii import hexlify, unhexlify from hmac import compare_digest from typing import Any, List, Optional import _ssl # type: ignore import urllib3 import urllib3.connection from ._base import RERAISE_EXCEPTIONS if sys.version_info < (3, 10) or sys.implementation.name != "cpython": raise ImportError("Only supported on CPython 3.10+") _ENCODING_DER: int = _ssl.ENCODING_DER _HASHES_BY_LENGTH = {32: hashlib.md5, 40: hashlib.sha1, 64: hashlib.sha256} __all__ = ["HTTPSConnectionPool"] class HTTPSConnection(urllib3.connection.HTTPSConnection): def __init__(self, *args: Any, **kwargs: Any) -> None: self._elastic_assert_fingerprint: Optional[str] = None super().__init__(*args, **kwargs) def connect(self) -> None: super().connect() # Hack to prevent a warning within HTTPSConnectionPool._validate_conn() if self._elastic_assert_fingerprint: self.is_verified = True class HTTPSConnectionPool(urllib3.HTTPSConnectionPool): ConnectionCls = HTTPSConnection """HTTPSConnectionPool implementation which supports ``assert_fingerprint`` on certificates within the chain instead of only the leaf cert using private APIs in CPython 3.10+ """ def __init__( self, *args: Any, assert_fingerprint: Optional[str] = None, **kwargs: Any ) -> None: self._elastic_assert_fingerprint = ( assert_fingerprint.replace(":", "").lower() if assert_fingerprint else None ) # Complain about fingerprint length earlier than urllib3 does. if ( self._elastic_assert_fingerprint and len(self._elastic_assert_fingerprint) not in _HASHES_BY_LENGTH ): valid_lengths = "', '".join(map(str, sorted(_HASHES_BY_LENGTH.keys()))) raise ValueError( f"Fingerprint of invalid length '{len(self._elastic_assert_fingerprint)}'" f", should be one of '{valid_lengths}'" ) if self._elastic_assert_fingerprint: # Skip fingerprinting by urllib3 as we'll do it ourselves kwargs["assert_fingerprint"] = None super().__init__(*args, **kwargs) def _new_conn(self) -> HTTPSConnection: """ Return a fresh :class:`urllib3.connection.HTTPSConnection`. """ conn: HTTPSConnection = super()._new_conn() # type: ignore[assignment] # Tell our custom connection if we'll assert fingerprint ourselves conn._elastic_assert_fingerprint = self._elastic_assert_fingerprint return conn def _validate_conn(self, conn: HTTPSConnection) -> None: # type: ignore[override] """ Called right before a request is made, after the socket is created. """ super(HTTPSConnectionPool, self)._validate_conn(conn) if self._elastic_assert_fingerprint: hash_func = _HASHES_BY_LENGTH[len(self._elastic_assert_fingerprint)] assert_fingerprint = unhexlify( self._elastic_assert_fingerprint.lower() .replace(":", "") .encode("ascii") ) fingerprints: List[bytes] try: if sys.version_info >= (3, 13): fingerprints = [ hash_func(cert).digest() for cert in conn.sock.get_verified_chain() # type: ignore ] else: # 'get_verified_chain()' and 'Certificate.public_bytes()' are private APIs # in CPython 3.10. They're not documented anywhere yet but seem to work # and we need them for Security on by Default so... onwards we go! # See: https://github.com/python/cpython/pull/25467 fingerprints = [ hash_func(cert.public_bytes(_ENCODING_DER)).digest() for cert in conn.sock._sslobj.get_verified_chain() # type: ignore[union-attr] ] except RERAISE_EXCEPTIONS: # pragma: nocover raise # Because these are private APIs we are super careful here # so that if anything "goes wrong" we fallback on the old behavior. except Exception: # pragma: nocover fingerprints = [] # Only add the peercert in front of the chain if it's not there for some reason. # This is to make sure old behavior of 'ssl_assert_fingerprint' still works. peercert_fingerprint = hash_func(conn.sock.getpeercert(True)).digest() # type: ignore[union-attr] if peercert_fingerprint not in fingerprints: # pragma: nocover fingerprints.insert(0, peercert_fingerprint) # If any match then that's a success! We always run them # all through though because of constant time concerns. success = False for fingerprint in fingerprints: success |= compare_digest(fingerprint, assert_fingerprint) # Give users all the fingerprints we checked against in # order of peer -> root CA. if not success: raise urllib3.exceptions.SSLError( 'Fingerprints did not match. Expected "{0}", got "{1}".'.format( self._elastic_assert_fingerprint, '", "'.join([x.decode() for x in map(hexlify, fingerprints)]), ) ) conn.is_verified = success elastic-transport-python-8.17.1/elastic_transport/_node_pool.py000066400000000000000000000352221476450415400250520ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import logging import random import threading import time from collections import defaultdict from queue import Empty, PriorityQueue from typing import ( TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, overload, ) from ._compat import Lock from ._models import NodeConfig from ._node import BaseNode if TYPE_CHECKING: from typing import Literal _logger = logging.getLogger("elastic_transport.node_pool") class NodeSelector: """ Simple class used to select a node from a list of currently live node instances. In init time it is passed a dictionary containing all the nodes options which it can then use during the selection process. When the ``select()`` method is called it is given a list of *currently* live nodes to choose from. The selector is initialized with the list of seed nodes that the NodePool was initialized with. This list of seed nodes can be used to make decisions within ``select()`` Example of where this would be useful is a zone-aware selector that would only select connections from it's own zones and only fall back to other connections where there would be none in its zones. """ def __init__(self, node_configs: List[NodeConfig]): """ :arg node_configs: List of NodeConfig instances """ self.node_configs = node_configs def select(self, nodes: Sequence[BaseNode]) -> BaseNode: # pragma: nocover """ Select a nodes from the given list. :arg nodes: list of live nodes to choose from """ raise NotImplementedError() class RandomSelector(NodeSelector): """Randomly select a node""" def select(self, nodes: Sequence[BaseNode]) -> BaseNode: return random.choice(nodes) class RoundRobinSelector(NodeSelector): """Select a node using round-robin""" def __init__(self, node_configs: List[NodeConfig]): super().__init__(node_configs) self._thread_local = threading.local() def select(self, nodes: Sequence[BaseNode]) -> BaseNode: self._thread_local.rr = (getattr(self._thread_local, "rr", -1) + 1) % len(nodes) return nodes[self._thread_local.rr] _SELECTOR_CLASS_NAMES: Dict[str, Type[NodeSelector]] = { "round_robin": RoundRobinSelector, "random": RandomSelector, } class NodePool: """ Container holding the :class:`~elastic_transport.BaseNode` instances, managing the selection process (via a :class:`~elastic_transport.NodeSelector`) and dead connections. It's only interactions are with the :class:`~elastic_transport.Transport` class that drives all the actions within ``NodePool``. Initially nodes are stored on the class as a list and, along with the connection options, get passed to the ``NodeSelector`` instance for future reference. Upon each request the ``Transport`` will ask for a ``BaseNode`` via the ``get_node`` method. If the connection fails (it's `perform_request` raises a `ConnectionError`) it will be marked as dead (via `mark_dead`) and put on a timeout (if it fails N times in a row the timeout is exponentially longer - the formula is `default_timeout * 2 ** (fail_count - 1)`). When the timeout is over the connection will be resurrected and returned to the live pool. A connection that has been previously marked as dead and succeeds will be marked as live (its fail count will be deleted). """ def __init__( self, node_configs: List[NodeConfig], node_class: Type[BaseNode], dead_node_backoff_factor: float = 1.0, max_dead_node_backoff: float = 30.0, node_selector_class: Union[str, Type[NodeSelector]] = RoundRobinSelector, randomize_nodes: bool = True, ): """ :arg node_configs: List of initial NodeConfigs to use :arg node_class: Type to use when creating nodes :arg dead_node_backoff_factor: Number of seconds used as a factor in calculating the amount of "backoff" time we should give a node after an unsuccessful request. The formula is calculated as follows where N is the number of consecutive failures: ``min(dead_backoff_factor * (2 ** (N - 1)), max_dead_backoff)`` :arg max_dead_node_backoff: Maximum number of seconds to wait when calculating the "backoff" time for a dead node. :arg node_selector_class: :class:`~elastic_transport.NodeSelector` subclass to use if more than one connection is live :arg randomize_nodes: shuffle the list of nodes upon instantiation to avoid dog-piling effect across processes """ if not node_configs: raise ValueError("Must specify at least one NodeConfig") node_configs = list( node_configs ) # Make a copy so we don't have side-effects outside. if any(not isinstance(node_config, NodeConfig) for node_config in node_configs): raise TypeError("NodePool must be passed a list of NodeConfig instances") if isinstance(node_selector_class, str): if node_selector_class not in _SELECTOR_CLASS_NAMES: raise ValueError( "Unknown option for selector_class: '%s'. " "Available options are: '%s'" % ( node_selector_class, "', '".join(sorted(_SELECTOR_CLASS_NAMES.keys())), ) ) node_selector_class = _SELECTOR_CLASS_NAMES[node_selector_class] if randomize_nodes: # randomize the list of nodes to avoid hammering the same node # if a large set of clients are created all at once. random.shuffle(node_configs) # Initial set of nodes that the NodePool was initialized with. # This set of nodes can never be removed. self._seed_nodes: Tuple[NodeConfig, ...] = tuple(set(node_configs)) if len(self._seed_nodes) != len(node_configs): raise ValueError("Cannot use duplicate NodeConfigs within a NodePool") self._node_class = node_class self._node_selector = node_selector_class(node_configs) # _all_nodes relies on dict insert order self._all_nodes: Dict[NodeConfig, BaseNode] = {} for node_config in node_configs: self._all_nodes[node_config] = self._node_class(node_config) # Lock that is used to protect writing to 'all_nodes' self._all_nodes_write_lock = Lock() # Flag which tells NodePool.get() that there's only one node # which allows for optimizations. Setting this flag is also # protected by the above write lock. self._all_nodes_len_1 = len(self._all_nodes) == 1 # Collection of currently-alive nodes. This is an ordered # dict so round-robin actually works. self._alive_nodes: Dict[NodeConfig, BaseNode] = dict(self._all_nodes) # PriorityQueue for thread safety and ease of timeout management self._dead_nodes: PriorityQueue[Tuple[float, BaseNode]] = PriorityQueue() self._dead_consecutive_failures: Dict[NodeConfig, int] = defaultdict(int) # Nodes that have been marked as 'removed' to be thread-safe. self._removed_nodes: Set[NodeConfig] = set() # default timeout after which to try resurrecting a connection self._dead_node_backoff_factor = dead_node_backoff_factor self._max_dead_node_backoff = max_dead_node_backoff @property def node_class(self) -> Type[BaseNode]: return self._node_class @property def node_selector(self) -> NodeSelector: return self._node_selector @property def dead_node_backoff_factor(self) -> float: return self._dead_node_backoff_factor @property def max_dead_node_backoff(self) -> float: return self._max_dead_node_backoff def mark_dead(self, node: BaseNode, _now: Optional[float] = None) -> None: """ Mark the node as dead (failed). Remove it from the live pool and put it on a timeout. :arg node: The failed node. """ now: float = _now if _now is not None else time.time() try: del self._alive_nodes[node.config] except KeyError: pass consecutive_failures = self._dead_consecutive_failures[node.config] + 1 self._dead_consecutive_failures[node.config] = consecutive_failures try: timeout = min( self._dead_node_backoff_factor * (2 ** (consecutive_failures - 1)), self._max_dead_node_backoff, ) except OverflowError: timeout = self._max_dead_node_backoff self._dead_nodes.put((now + timeout, node)) _logger.warning( "Node %r has failed for %i times in a row, putting on %i second timeout", node, consecutive_failures, timeout, ) def mark_live(self, node: BaseNode) -> None: """ Mark node as healthy after a resurrection. Resets the fail counter for the node. :arg node: The ``BaseNode`` instance to mark as alive. """ try: del self._dead_consecutive_failures[node.config] except KeyError: # race condition, safe to ignore pass else: self._alive_nodes.setdefault(node.config, node) _logger.warning( "Node %r has been marked alive after a successful request", node, ) @overload def resurrect(self, force: "Literal[True]" = ...) -> BaseNode: ... @overload def resurrect(self, force: "Literal[False]" = ...) -> Optional[BaseNode]: ... def resurrect(self, force: bool = False) -> Optional[BaseNode]: """ Attempt to resurrect a node from the dead queue. It will try to locate one (not all) eligible (it's timeout is over) node to return to the live pool. Any resurrected node is also returned. :arg force: resurrect a node even if there is none eligible (used when we have no live nodes). If force is 'True'' resurrect always returns a node. """ node: Optional[BaseNode] mark_node_alive_after: float = 0.0 try: # Try to resurrect a dead node if any. mark_node_alive_after, node = self._dead_nodes.get(block=False) except Empty: # No dead nodes. if force: # If we're being forced to return a node we randomly # pick between alive and dead nodes. return random.choice(list(self._all_nodes.values())) node = None if node is not None and not force and mark_node_alive_after > time.time(): # return it back if not eligible and not forced self._dead_nodes.put((mark_node_alive_after, node)) node = None # either we were forced or the node is eligible to be retried if node is not None: self._alive_nodes[node.config] = node _logger.info("Resurrected node %r (force=%s)", node, force) return node def add(self, node_config: NodeConfig) -> None: try: # If the node was previously removed we mark it as "in the pool" self._removed_nodes.remove(node_config) except KeyError: pass with self._all_nodes_write_lock: # We don't error when trying to add a duplicate node # to the pool because threading+sniffing can call # .add() on the same NodeConfig. if node_config not in self._all_nodes: node = self._node_class(node_config) self._all_nodes[node.config] = node # Update the flag to disable optimizations. Also ensures that # .resurrect() starts getting called so our added node makes # it way into the alive nodes. self._all_nodes_len_1 = False # Start the node as dead because 'dead_nodes' is thread-safe. # The node will be resurrected on the next call to .get() self._dead_consecutive_failures[node.config] = 0 self._dead_nodes.put((time.time(), node)) def remove(self, node_config: NodeConfig) -> None: # Can't mark a seed node as removed. if node_config not in self._seed_nodes: self._removed_nodes.add(node_config) def get(self) -> BaseNode: """ Return a node from the pool using the ``NodeSelector`` instance. It tries to resurrect eligible nodes, forces a resurrection when no nodes are available and passes the list of live nodes to the selector instance to choose from. """ # Even with the optimization below we want to participate in the # dead/alive cycle in case more nodes join after sniffing, for example. self.resurrect() # Flag that short-circuits the extra logic if we have only one node. # The only way this flag can be set to 'True' is if there were only # one node defined within 'seed_nodes' so we know this good to do. if self._all_nodes_len_1: return self._all_nodes[self._seed_nodes[0]] # Filter nodes in 'alive_nodes' to ones not marked as removed. nodes = [ node for node_config, node in self._alive_nodes.items() if node_config not in self._removed_nodes ] # No live nodes, resurrect one by force and return it if not nodes: return self.resurrect(force=True) # Only call selector if we have a choice to make if len(nodes) > 1: return self._node_selector.select(nodes) return nodes[0] def all(self) -> List[BaseNode]: return list(self._all_nodes.values()) def __repr__(self) -> str: return "" def __len__(self) -> int: return len(self._all_nodes) elastic-transport-python-8.17.1/elastic_transport/_otel.py000066400000000000000000000056111476450415400240360ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. from __future__ import annotations from typing import TYPE_CHECKING, Literal, Mapping if TYPE_CHECKING: from opentelemetry.trace import Span # A list of the Elasticsearch endpoints that qualify as "search" endpoints. The search query in # the request body may be captured for these endpoints, depending on the body capture strategy. SEARCH_ENDPOINTS = ( "search", "async_search.submit", "msearch", "eql.search", "esql.query", "terms_enum", "search_template", "msearch_template", "render_search_template", ) class OpenTelemetrySpan: def __init__( self, otel_span: Span | None, endpoint_id: str | None = None, body_strategy: Literal["omit", "raw"] = "omit", ): self.otel_span = otel_span self.body_strategy = body_strategy self.endpoint_id = endpoint_id def set_node_metadata( self, host: str, port: int, base_url: str, target: str ) -> None: if self.otel_span is None: return # url.full does not contain auth info which is passed as headers self.otel_span.set_attribute("url.full", base_url + target) self.otel_span.set_attribute("server.address", host) self.otel_span.set_attribute("server.port", port) def set_elastic_cloud_metadata(self, headers: Mapping[str, str]) -> None: if self.otel_span is None: return cluster_name = headers.get("X-Found-Handling-Cluster") if cluster_name is not None: self.otel_span.set_attribute("db.elasticsearch.cluster.name", cluster_name) node_name = headers.get("X-Found-Handling-Instance") if node_name is not None: self.otel_span.set_attribute("db.elasticsearch.node.name", node_name) def set_db_statement(self, serialized_body: bytes) -> None: if self.otel_span is None: return if self.body_strategy == "omit": return elif self.body_strategy == "raw" and self.endpoint_id in SEARCH_ENDPOINTS: self.otel_span.set_attribute( "db.statement", serialized_body.decode("utf-8") ) elastic-transport-python-8.17.1/elastic_transport/_response.py000066400000000000000000000144051476450415400247320ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. from typing import ( Any, Dict, Generic, Iterator, List, NoReturn, Tuple, TypeVar, Union, overload, ) from ._models import ApiResponseMeta _BodyType = TypeVar("_BodyType") _ObjectBodyType = TypeVar("_ObjectBodyType") _ListItemBodyType = TypeVar("_ListItemBodyType") class ApiResponse(Generic[_BodyType]): """Base class for all API response classes""" __slots__ = ("_body", "_meta") def __init__( self, *args: Any, **kwargs: Any, ): def _raise_typeerror() -> NoReturn: raise TypeError("Must pass 'meta' and 'body' to ApiResponse") from None # Working around pre-releases of elasticsearch-python # that would use raw=... instead of body=... try: if bool(args) == bool(kwargs): _raise_typeerror() elif args and len(args) == 2: body, meta = args elif kwargs and "raw" in kwargs: body = kwargs.pop("raw") meta = kwargs.pop("meta") kwargs.pop("body_cls", None) elif kwargs and "body" in kwargs: body = kwargs.pop("body") meta = kwargs.pop("meta") kwargs.pop("body_cls", None) else: _raise_typeerror() except KeyError: _raise_typeerror() # If there are still kwargs left over # and we're not in positional mode... if not args and kwargs: _raise_typeerror() self._body = body self._meta = meta def __repr__(self) -> str: return f"{type(self).__name__}({self.body!r})" def __contains__(self, item: Any) -> bool: return item in self._body def __eq__(self, other: object) -> bool: if isinstance(other, ApiResponse): other = other.body return self._body == other # type: ignore[no-any-return] def __ne__(self, other: object) -> bool: if isinstance(other, ApiResponse): other = other.body return self._body != other # type: ignore[no-any-return] def __getitem__(self, item: Any) -> Any: return self._body[item] def __getattr__(self, attr: str) -> Any: return getattr(self._body, attr) def __getstate__(self) -> Tuple[_BodyType, ApiResponseMeta]: return self._body, self._meta def __setstate__(self, state: Tuple[_BodyType, ApiResponseMeta]) -> None: self._body, self._meta = state def __len__(self) -> int: return len(self._body) def __iter__(self) -> Iterator[Any]: return iter(self._body) def __str__(self) -> str: return str(self._body) def __bool__(self) -> bool: return bool(self._body) @property def meta(self) -> ApiResponseMeta: """Response metadata""" return self._meta # type: ignore[no-any-return] @property def body(self) -> _BodyType: """User-friendly view into the raw response with type hints if applicable""" return self._body # type: ignore[no-any-return] @property def raw(self) -> _BodyType: return self.body class TextApiResponse(ApiResponse[str]): """API responses which are text such as 'text/plain' or 'text/csv'""" def __iter__(self) -> Iterator[str]: return iter(self.body) def __getitem__(self, item: Union[int, slice]) -> str: return self.body[item] @property def body(self) -> str: return self._body # type: ignore[no-any-return] class BinaryApiResponse(ApiResponse[bytes]): """API responses which are a binary response such as Mapbox vector tiles""" def __iter__(self) -> Iterator[int]: return iter(self.body) @overload def __getitem__(self, item: slice) -> bytes: ... @overload def __getitem__(self, item: int) -> int: ... def __getitem__(self, item: Union[int, slice]) -> Union[int, bytes]: return self.body[item] @property def body(self) -> bytes: return self._body # type: ignore[no-any-return] class HeadApiResponse(ApiResponse[bool]): """API responses which are for an 'exists' / HEAD API request""" def __init__(self, meta: ApiResponseMeta): super().__init__(body=200 <= meta.status < 300, meta=meta) def __bool__(self) -> bool: return 200 <= self.meta.status < 300 @property def body(self) -> bool: return bool(self) class ObjectApiResponse(Generic[_ObjectBodyType], ApiResponse[Dict[str, Any]]): """API responses which are for a JSON object""" def __getitem__(self, item: str) -> Any: return self.body[item] # type: ignore[index] def __iter__(self) -> Iterator[str]: return iter(self._body) @property def body(self) -> _ObjectBodyType: # type: ignore[override] return self._body # type: ignore[no-any-return] class ListApiResponse( Generic[_ListItemBodyType], ApiResponse[List[Any]], ): """API responses which are a list of items. Can be NDJSON or a JSON list""" @overload def __getitem__(self, item: slice) -> List[_ListItemBodyType]: ... @overload def __getitem__(self, item: int) -> _ListItemBodyType: ... def __getitem__( self, item: Union[int, slice] ) -> Union[_ListItemBodyType, List[_ListItemBodyType]]: return self.body[item] def __iter__(self) -> Iterator[_ListItemBodyType]: return iter(self.body) @property def body(self) -> List[_ListItemBodyType]: return self._body # type: ignore[no-any-return] elastic-transport-python-8.17.1/elastic_transport/_serializer.py000066400000000000000000000206471476450415400252520ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import json import re import uuid from datetime import date from decimal import Decimal from typing import Any, ClassVar, Mapping, Optional from ._exceptions import SerializationError try: import orjson except ModuleNotFoundError: orjson = None # type: ignore[assignment] class Serializer: """Serializer interface.""" mimetype: ClassVar[str] def loads(self, data: bytes) -> Any: # pragma: nocover raise NotImplementedError() def dumps(self, data: Any) -> bytes: # pragma: nocover raise NotImplementedError() class TextSerializer(Serializer): """Text serializer to and from UTF-8.""" mimetype: ClassVar[str] = "text/*" def loads(self, data: bytes) -> str: if isinstance(data, str): return data try: return data.decode("utf-8", "surrogatepass") except UnicodeError as e: raise SerializationError( f"Unable to deserialize as text: {data!r}", errors=(e,) ) def dumps(self, data: str) -> bytes: # The body is already encoded to bytes # so we forward the request body along. if isinstance(data, bytes): return data try: return data.encode("utf-8", "surrogatepass") except (AttributeError, UnicodeError, TypeError) as e: raise SerializationError( f"Unable to serialize to text: {data!r}", errors=(e,) ) class JsonSerializer(Serializer): """JSON serializer relying on the standard library json module.""" mimetype: ClassVar[str] = "application/json" def default(self, data: Any) -> Any: if isinstance(data, date): return data.isoformat() elif isinstance(data, uuid.UUID): return str(data) elif isinstance(data, Decimal): return float(data) raise SerializationError( message=f"Unable to serialize to JSON: {data!r} (type: {type(data).__name__})", ) def json_dumps(self, data: Any) -> bytes: return json.dumps( data, default=self.default, ensure_ascii=False, separators=(",", ":") ).encode("utf-8", "surrogatepass") def json_loads(self, data: bytes) -> Any: return json.loads(data) def loads(self, data: bytes) -> Any: # Sometimes responses use Content-Type: json but actually # don't contain any data. We should return something instead # of erroring in these cases. if data == b"": return None try: return self.json_loads(data) except (ValueError, TypeError) as e: raise SerializationError( message=f"Unable to deserialize as JSON: {data!r}", errors=(e,) ) def dumps(self, data: Any) -> bytes: # The body is already encoded to bytes # so we forward the request body along. if isinstance(data, str): return data.encode("utf-8", "surrogatepass") elif isinstance(data, bytes): return data try: return self.json_dumps(data) # This should be captured by the .default() # call but just in case we also wrap these. except (ValueError, UnicodeError, TypeError) as e: # pragma: nocover raise SerializationError( message=f"Unable to serialize to JSON: {data!r} (type: {type(data).__name__})", errors=(e,), ) if orjson is not None: class OrjsonSerializer(JsonSerializer): """JSON serializer relying on the orjson package. Only available if orjson if installed. It is faster, especially for vectors, but is also stricter. """ def json_dumps(self, data: Any) -> bytes: return orjson.dumps( data, default=self.default, option=orjson.OPT_SERIALIZE_NUMPY ) def json_loads(self, data: bytes) -> Any: return orjson.loads(data) class NdjsonSerializer(JsonSerializer): """Newline delimited JSON (NDJSON) serializer relying on the standard library json module.""" mimetype: ClassVar[str] = "application/x-ndjson" def loads(self, data: bytes) -> Any: ndjson = [] for line in re.split(b"[\n\r]", data): if not line: continue try: ndjson.append(self.json_loads(line)) except (ValueError, TypeError) as e: raise SerializationError( message=f"Unable to deserialize as NDJSON: {data!r}", errors=(e,) ) return ndjson def dumps(self, data: Any) -> bytes: # The body is already encoded to bytes # so we forward the request body along. if isinstance(data, (bytes, str)): data = (data,) buffer = bytearray() for line in data: if isinstance(line, str): line = line.encode("utf-8", "surrogatepass") if isinstance(line, bytes): buffer += line # Ensure that there is always a final newline if not line.endswith(b"\n"): buffer += b"\n" else: try: buffer += self.json_dumps(line) buffer += b"\n" # This should be captured by the .default() # call but just in case we also wrap these. except (ValueError, UnicodeError, TypeError) as e: # pragma: nocover raise SerializationError( message=f"Unable to serialize to NDJSON: {data!r} (type: {type(data).__name__})", errors=(e,), ) return bytes(buffer) DEFAULT_SERIALIZERS = { JsonSerializer.mimetype: JsonSerializer(), TextSerializer.mimetype: TextSerializer(), NdjsonSerializer.mimetype: NdjsonSerializer(), } class SerializerCollection: """Collection of serializers that can be fetched by mimetype. Used by :class:`elastic_transport.Transport` to serialize and deserialize native Python types into bytes before passing to a node. """ def __init__( self, serializers: Optional[Mapping[str, Serializer]] = None, default_mimetype: str = "application/json", ): if serializers is None: serializers = DEFAULT_SERIALIZERS try: self.default_serializer = serializers[default_mimetype] except KeyError: raise ValueError( f"Must configure a serializer for the default mimetype {default_mimetype!r}" ) from None self.serializers = dict(serializers) def dumps(self, data: Any, mimetype: Optional[str] = None) -> bytes: return self.get_serializer(mimetype).dumps(data) def loads(self, data: bytes, mimetype: Optional[str] = None) -> Any: return self.get_serializer(mimetype).loads(data) def get_serializer(self, mimetype: Optional[str]) -> Serializer: # split out charset if mimetype is None: serializer = self.default_serializer else: mimetype, _, _ = mimetype.partition(";") try: serializer = self.serializers[mimetype] except KeyError: # Try for '/*' types after the specific type fails. try: mimetype_supertype = mimetype.partition("/")[0] serializer = self.serializers[f"{mimetype_supertype}/*"] except KeyError: raise SerializationError( f"Unknown mimetype, not able to serialize or deserialize: {mimetype}" ) from None return serializer elastic-transport-python-8.17.1/elastic_transport/_transport.py000066400000000000000000000560071476450415400251340ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import dataclasses import inspect import logging import time import warnings from platform import python_version from typing import ( Any, Callable, Collection, Dict, List, Mapping, NamedTuple, Optional, Tuple, Type, Union, cast, ) from ._compat import Lock, warn_stacklevel from ._exceptions import ( ConnectionError, ConnectionTimeout, SniffingError, TransportError, TransportWarning, ) from ._models import ( DEFAULT, ApiResponseMeta, DefaultType, HttpHeaders, NodeConfig, SniffOptions, ) from ._node import ( AiohttpHttpNode, BaseNode, HttpxAsyncHttpNode, RequestsHttpNode, Urllib3HttpNode, ) from ._node_pool import NodePool, NodeSelector from ._otel import OpenTelemetrySpan from ._serializer import DEFAULT_SERIALIZERS, Serializer, SerializerCollection from ._version import __version__ from .client_utils import client_meta_version, resolve_default # Allows for using a node_class by name rather than import. NODE_CLASS_NAMES: Dict[str, Type[BaseNode]] = { "urllib3": Urllib3HttpNode, "requests": RequestsHttpNode, "aiohttp": AiohttpHttpNode, "httpxasync": HttpxAsyncHttpNode, } # These are HTTP status errors that shouldn't be considered # 'errors' for marking a node as dead. These errors typically # mean everything is fine server-wise and instead the API call # in question responded successfully. NOT_DEAD_NODE_HTTP_STATUSES = {None, 400, 401, 402, 403, 404, 409} DEFAULT_CLIENT_META_SERVICE = ("et", client_meta_version(__version__)) _logger = logging.getLogger("elastic_transport.transport") class TransportApiResponse(NamedTuple): meta: ApiResponseMeta body: Any class Transport: """ Encapsulation of transport-related to logic. Handles instantiation of the individual nodes as well as creating a node pool to hold them. Main interface is the :meth:`elastic_transport.Transport.perform_request` method. """ def __init__( self, node_configs: List[NodeConfig], node_class: Union[str, Type[BaseNode]] = Urllib3HttpNode, node_pool_class: Type[NodePool] = NodePool, randomize_nodes_in_pool: bool = True, node_selector_class: Optional[Union[str, Type[NodeSelector]]] = None, dead_node_backoff_factor: Optional[float] = None, max_dead_node_backoff: Optional[float] = None, serializers: Optional[Mapping[str, Serializer]] = None, default_mimetype: str = "application/json", max_retries: int = 3, retry_on_status: Collection[int] = (429, 502, 503, 504), retry_on_timeout: bool = False, sniff_on_start: bool = False, sniff_before_requests: bool = False, sniff_on_node_failure: bool = False, sniff_timeout: Optional[float] = 0.5, min_delay_between_sniffing: float = 10.0, sniff_callback: Optional[ Callable[ ["Transport", "SniffOptions"], Union[List[NodeConfig], List[NodeConfig]], ] ] = None, meta_header: bool = True, client_meta_service: Tuple[str, str] = DEFAULT_CLIENT_META_SERVICE, ): """ :arg node_configs: List of 'NodeConfig' instances to create initial set of nodes. :arg node_class: subclass of :class:`~elastic_transport.BaseNode` to use or the name of the Connection (ie 'urllib3', 'requests') :arg node_pool_class: subclass of :class:`~elastic_transport.NodePool` to use :arg randomize_nodes_in_pool: Set to false to not randomize nodes within the pool. Defaults to true. :arg node_selector_class: Class to be used to select nodes within the :class:`~elastic_transport.NodePool`. :arg dead_node_backoff_factor: Exponential backoff factor to calculate the amount of time to timeout a node after an unsuccessful API call. :arg max_dead_node_backoff: Maximum amount of time to timeout a node after an unsuccessful API call. :arg serializers: optional dict of serializer instances that will be used for deserializing data coming from the server. (key is the mimetype) :arg max_retries: Maximum number of retries for an API call. Set to 0 to disable retries. Defaults to ``0``. :arg retry_on_status: set of HTTP status codes on which we should retry on a different node. defaults to ``(429, 502, 503, 504)`` :arg retry_on_timeout: should timeout trigger a retry on different node? (default ``False``) :arg sniff_on_start: If ``True`` will sniff for additional nodes as soon as possible, guaranteed before the first request. :arg sniff_on_node_failure: If ``True`` will sniff for additional nodees after a node is marked as dead in the pool. :arg sniff_before_requests: If ``True`` will occasionally sniff for additional nodes as requests are sent. :arg sniff_timeout: Timeout value in seconds to use for sniffing requests. Defaults to 1 second. :arg min_delay_between_sniffing: Number of seconds to wait between calls to :meth:`elastic_transport.Transport.sniff` to avoid sniffing too frequently. Defaults to 10 seconds. :arg sniff_callback: Function that is passed a :class:`elastic_transport.Transport` and :class:`elastic_transport.SniffOptions` and should do node discovery and return a list of :class:`elastic_transport.NodeConfig` instances. :arg meta_header: If set to False the ``X-Elastic-Client-Meta`` HTTP header won't be sent. Defaults to True. :arg client_meta_service: Key-value pair for the service field of the client metadata header. Defaults to the service key-value for Elastic Transport. """ if isinstance(node_class, str): if node_class not in NODE_CLASS_NAMES: options = "', '".join(sorted(NODE_CLASS_NAMES.keys())) raise ValueError( f"Unknown option for node_class: '{node_class}'. " f"Available options are: '{options}'" ) node_class = NODE_CLASS_NAMES[node_class] # Verify that the node_class we're passed is # async/sync the same as the transport is. is_transport_async = inspect.iscoroutinefunction(self.perform_request) is_node_async = inspect.iscoroutinefunction(node_class.perform_request) if is_transport_async != is_node_async: raise ValueError( f"Specified 'node_class' {'is' if is_node_async else 'is not'} async, " f"should be {'async' if is_transport_async else 'sync'} instead" ) validate_sniffing_options( node_configs=node_configs, sniff_on_start=sniff_on_start, sniff_before_requests=sniff_before_requests, sniff_on_node_failure=sniff_on_node_failure, sniff_callback=sniff_callback, ) # Create the default metadata for the x-elastic-client-meta # HTTP header. Only requires adding the (service, service_version) # tuple to the beginning of the client_meta self._transport_client_meta: Tuple[Tuple[str, str], ...] = ( client_meta_service, ("py", client_meta_version(python_version())), ("t", client_meta_version(__version__)), ) # Grab the 'HTTP_CLIENT_META' property from the node class http_client_meta = cast( Optional[Tuple[str, str]], getattr(node_class, "_CLIENT_META_HTTP_CLIENT", None), ) if http_client_meta: self._transport_client_meta += (http_client_meta,) if not isinstance(meta_header, bool): raise TypeError("'meta_header' must be of type bool") self.meta_header = meta_header # serialization config _serializers = DEFAULT_SERIALIZERS.copy() # if custom serializers map has been supplied, override the defaults with it if serializers: _serializers.update(serializers) # Create our collection of serializers self.serializers = SerializerCollection( _serializers, default_mimetype=default_mimetype ) # Set of default request options self.max_retries = max_retries self.retry_on_status = retry_on_status self.retry_on_timeout = retry_on_timeout # Build the NodePool from all the options node_pool_kwargs: Dict[str, Any] = {} if node_selector_class is not None: node_pool_kwargs["node_selector_class"] = node_selector_class if dead_node_backoff_factor is not None: node_pool_kwargs["dead_node_backoff_factor"] = dead_node_backoff_factor if max_dead_node_backoff is not None: node_pool_kwargs["max_dead_node_backoff"] = max_dead_node_backoff self.node_pool: NodePool = node_pool_class( node_configs, node_class=node_class, randomize_nodes=randomize_nodes_in_pool, **node_pool_kwargs, ) self._sniff_on_start = sniff_on_start self._sniff_before_requests = sniff_before_requests self._sniff_on_node_failure = sniff_on_node_failure self._sniff_timeout = sniff_timeout self._sniff_callback = sniff_callback self._sniffing_lock = Lock() # Used to track whether we're currently sniffing. self._min_delay_between_sniffing = min_delay_between_sniffing self._last_sniffed_at = 0.0 if sniff_on_start: self.sniff(True) def perform_request( # type: ignore[return] self, method: str, target: str, *, body: Optional[Any] = None, headers: Union[Mapping[str, Any], DefaultType] = DEFAULT, max_retries: Union[int, DefaultType] = DEFAULT, retry_on_status: Union[Collection[int], DefaultType] = DEFAULT, retry_on_timeout: Union[bool, DefaultType] = DEFAULT, request_timeout: Union[Optional[float], DefaultType] = DEFAULT, client_meta: Union[Tuple[Tuple[str, str], ...], DefaultType] = DEFAULT, otel_span: Union[OpenTelemetrySpan, DefaultType] = DEFAULT, ) -> TransportApiResponse: """ Perform the actual request. Retrieve a node from the node pool, pass all the information to it's perform_request method and return the data. If an exception was raised, mark the node as failed and retry (up to ``max_retries`` times). If the operation was successful and the node used was previously marked as dead, mark it as live, resetting it's failure count. :arg method: HTTP method to use :arg target: HTTP request target :arg body: body of the request, will be serialized using serializer and passed to the node :arg headers: Additional headers to send with the request. :arg max_retries: Maximum number of retries before giving up on a request. Set to ``0`` to disable retries. :arg retry_on_status: Collection of HTTP status codes to retry. :arg retry_on_timeout: Set to true to retry after timeout errors. :arg request_timeout: Amount of time to wait for a response to fail with a timeout error. :arg client_meta: Extra client metadata key-value pairs to send in the client meta header. :arg otel_span: OpenTelemetry span used to add metadata to the span. :returns: Tuple of the :class:`elastic_transport.ApiResponseMeta` with the deserialized response. """ if headers is DEFAULT: request_headers = HttpHeaders() else: request_headers = HttpHeaders(headers) max_retries = resolve_default(max_retries, self.max_retries) retry_on_timeout = resolve_default(retry_on_timeout, self.retry_on_timeout) retry_on_status = resolve_default(retry_on_status, self.retry_on_status) otel_span = resolve_default(otel_span, OpenTelemetrySpan(None)) if self.meta_header: request_headers["x-elastic-client-meta"] = ",".join( f"{k}={v}" for k, v in self._transport_client_meta + resolve_default(client_meta, ()) ) # Serialize the request body to bytes based on the given mimetype. request_body: Optional[bytes] if body is not None: if "content-type" not in request_headers: raise ValueError( "Must provide a 'Content-Type' header to requests with bodies" ) request_body = self.serializers.dumps( body, mimetype=request_headers["content-type"] ) otel_span.set_db_statement(request_body) else: request_body = None # Errors are stored from (oldest->newest) errors: List[Exception] = [] for attempt in range(max_retries + 1): # If we sniff before requests are made we want to do so before # 'node_pool.get()' is called so our sniffed nodes show up in the pool. if self._sniff_before_requests: self.sniff(False) retry = False node_failure = False last_response: Optional[TransportApiResponse] = None node = self.node_pool.get() start_time = time.time() try: otel_span.set_node_metadata(node.host, node.port, node.base_url, target) resp = node.perform_request( method, target, body=request_body, headers=request_headers, request_timeout=request_timeout, ) _logger.info( "%s %s%s [status:%s duration:%.3fs]" % ( method, node.base_url, target, resp.meta.status, time.time() - start_time, ) ) if method != "HEAD": body = self.serializers.loads(resp.body, resp.meta.mimetype) else: body = None if resp.meta.status in retry_on_status: retry = True # Keep track of the last response we see so we can return # it in case the retried request returns with a transport error. last_response = TransportApiResponse(resp.meta, body) except TransportError as e: _logger.info( "%s %s%s [status:%s duration:%.3fs]" % ( method, node.base_url, target, "N/A", time.time() - start_time, ) ) if isinstance(e, ConnectionTimeout): retry = retry_on_timeout node_failure = True elif isinstance(e, ConnectionError): retry = True node_failure = True # If the error was determined to be a node failure # we mark it dead in the node pool to allow for # other nodes to be retried. if node_failure: self.node_pool.mark_dead(node) if self._sniff_on_node_failure: try: self.sniff(False) except TransportError: # If sniffing on failure, it could fail too. Catch the # exception not to interrupt the retries. pass if not retry or attempt >= max_retries: # Since we're exhausted but we have previously # received some sort of response from the API # we should forward that along instead of the # transport error. Likely to be more actionable. if last_response is not None: return last_response e.errors = tuple(errors) raise else: _logger.warning( "Retrying request after failure (attempt %d of %d)", attempt, max_retries, exc_info=e, ) errors.append(e) else: # If we got back a response we need to check if that status # is indicative of a healthy node even if it's a non-2XX status if ( 200 <= resp.meta.status < 299 or resp.meta.status in NOT_DEAD_NODE_HTTP_STATUSES ): self.node_pool.mark_live(node) else: self.node_pool.mark_dead(node) if self._sniff_on_node_failure: try: self.sniff(False) except TransportError: # If sniffing on failure, it could fail too. Catch the # exception not to interrupt the retries. pass # We either got a response we're happy with or # we've exhausted all of our retries so we return it. if not retry or attempt >= max_retries: return TransportApiResponse(resp.meta, body) else: _logger.warning( "Retrying request after non-successful status %d (attempt %d of %d)", resp.meta.status, attempt, max_retries, ) def sniff(self, is_initial_sniff: bool = False) -> None: previously_sniffed_at = self._last_sniffed_at should_sniff = self._should_sniff(is_initial_sniff) try: if should_sniff: _logger.info("Started sniffing for additional nodes") self._last_sniffed_at = time.time() options = SniffOptions( is_initial_sniff=is_initial_sniff, sniff_timeout=self._sniff_timeout ) assert self._sniff_callback is not None node_configs = self._sniff_callback(self, options) if not node_configs and is_initial_sniff: raise SniffingError( "No viable nodes were discovered on the initial sniff attempt" ) prev_node_pool_size = len(self.node_pool) for node_config in node_configs: self.node_pool.add(node_config) # Do some math to log which nodes are new/existing sniffed_nodes = len(node_configs) new_nodes = sniffed_nodes - (len(self.node_pool) - prev_node_pool_size) existing_nodes = sniffed_nodes - new_nodes _logger.debug( "Discovered %d nodes during sniffing (%d new nodes, %d already in pool)", sniffed_nodes, new_nodes, existing_nodes, ) # If sniffing failed for any reason we # want to allow retrying immediately. except Exception as e: _logger.warning("Encountered an error during sniffing", exc_info=e) self._last_sniffed_at = previously_sniffed_at raise # If we started a sniff we need to release the lock. finally: if should_sniff: self._sniffing_lock.release() def close(self) -> None: """ Explicitly closes all nodes in the transport's pool """ for node in self.node_pool.all(): node.close() def _should_sniff(self, is_initial_sniff: bool) -> bool: """Decide if we should sniff or not. If we return ``True`` from this method the caller has a responsibility to unlock the ``_sniffing_lock`` """ if not is_initial_sniff and ( time.time() - self._last_sniffed_at < self._min_delay_between_sniffing ): return False return self._sniffing_lock.acquire(False) def validate_sniffing_options( *, node_configs: List[NodeConfig], sniff_before_requests: bool, sniff_on_start: bool, sniff_on_node_failure: bool, sniff_callback: Optional[Any], ) -> None: """Validates the Transport configurations for sniffing""" sniffing_enabled = sniff_before_requests or sniff_on_start or sniff_on_node_failure if sniffing_enabled and not sniff_callback: raise ValueError("Enabling sniffing requires specifying a 'sniff_callback'") if not sniffing_enabled and sniff_callback: raise ValueError( "Using 'sniff_callback' requires enabling sniffing via 'sniff_on_start', " "'sniff_before_requests' or 'sniff_on_node_failure'" ) # If we're sniffing we want to warn the user for non-homogenous NodeConfigs. if sniffing_enabled and len(node_configs) > 1: warn_if_varying_node_config_options(node_configs) def warn_if_varying_node_config_options(node_configs: List[NodeConfig]) -> None: """Function which detects situations when sniffing may produce incorrect configs""" exempt_attrs = {"host", "port", "connections_per_node", "_extras", "ssl_context"} match_attr_dict = None for node_config in node_configs: attr_dict = { field.name: getattr(node_config, field.name) for field in dataclasses.fields(node_config) if field.name not in exempt_attrs } if match_attr_dict is None: match_attr_dict = attr_dict # Detected two nodes that have different config, warn the user. elif match_attr_dict != attr_dict: warnings.warn( "Detected NodeConfig instances with different options. " "It's recommended to keep all options except for " "'host' and 'port' the same for sniffing to work reliably.", category=TransportWarning, stacklevel=warn_stacklevel(), ) elastic-transport-python-8.17.1/elastic_transport/_utils.py000066400000000000000000000064371476450415400242420ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import re from typing import Any, Dict, Union def fixup_module_metadata(module_name: str, namespace: Dict[str, Any]) -> None: # Yoinked from python-trio/outcome, thanks Nathaniel! License: MIT def fix_one(obj: Any) -> None: mod = getattr(obj, "__module__", None) if mod is not None and mod.startswith("elastic_transport."): obj.__module__ = module_name if isinstance(obj, type): for attr_value in obj.__dict__.values(): fix_one(attr_value) for objname in namespace["__all__"]: obj = namespace[objname] fix_one(obj) IPV4_PAT = r"(?:[0-9]{1,3}\.){3}[0-9]{1,3}" IPV4_RE = re.compile("^" + IPV4_PAT + "$") HEX_PAT = "[0-9A-Fa-f]{1,4}" LS32_PAT = "(?:{hex}:{hex}|{ipv4})".format(hex=HEX_PAT, ipv4=IPV4_PAT) _subs = {"hex": HEX_PAT, "ls32": LS32_PAT} _variations = [ # 6( h16 ":" ) ls32 "(?:%(hex)s:){6}%(ls32)s", # "::" 5( h16 ":" ) ls32 "::(?:%(hex)s:){5}%(ls32)s", # [ h16 ] "::" 4( h16 ":" ) ls32 "(?:%(hex)s)?::(?:%(hex)s:){4}%(ls32)s", # [ *1( h16 ":" ) h16 ] "::" 3( h16 ":" ) ls32 "(?:(?:%(hex)s:)?%(hex)s)?::(?:%(hex)s:){3}%(ls32)s", # [ *2( h16 ":" ) h16 ] "::" 2( h16 ":" ) ls32 "(?:(?:%(hex)s:){0,2}%(hex)s)?::(?:%(hex)s:){2}%(ls32)s", # [ *3( h16 ":" ) h16 ] "::" h16 ":" ls32 "(?:(?:%(hex)s:){0,3}%(hex)s)?::%(hex)s:%(ls32)s", # [ *4( h16 ":" ) h16 ] "::" ls32 "(?:(?:%(hex)s:){0,4}%(hex)s)?::%(ls32)s", # [ *5( h16 ":" ) h16 ] "::" h16 "(?:(?:%(hex)s:){0,5}%(hex)s)?::%(hex)s", # [ *6( h16 ":" ) h16 ] "::" "(?:(?:%(hex)s:){0,6}%(hex)s)?::", ] IPV6_PAT = "(?:" + "|".join([x % _subs for x in _variations]) + ")" UNRESERVED_PAT = r"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._!\-~" ZONE_ID_PAT = "(?:%25|%)(?:[" + UNRESERVED_PAT + "]|%[a-fA-F0-9]{2})+" BRACELESS_IPV6_ADDRZ_PAT = IPV6_PAT + r"(?:" + ZONE_ID_PAT + r")?" BRACELESS_IPV6_ADDRZ_RE = re.compile("^" + BRACELESS_IPV6_ADDRZ_PAT + "$") def is_ipaddress(hostname: Union[str, bytes]) -> bool: """Detects whether the hostname given is an IPv4 or IPv6 address. Also detects IPv6 addresses with Zone IDs. """ # Copied from urllib3. License: MIT if isinstance(hostname, bytes): # IDN A-label bytes are ASCII compatible. hostname = hostname.decode("ascii") hostname = hostname.strip("[]") return bool(IPV4_RE.match(hostname) or BRACELESS_IPV6_ADDRZ_RE.match(hostname)) elastic-transport-python-8.17.1/elastic_transport/_version.py000066400000000000000000000014531476450415400245600ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. __version__ = "8.17.1" elastic-transport-python-8.17.1/elastic_transport/client_utils.py000066400000000000000000000200431476450415400254260ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import base64 import binascii import dataclasses import re import urllib.parse from platform import python_version from typing import Optional, Tuple, TypeVar, Union from urllib.parse import quote as _quote from urllib3.exceptions import LocationParseError from urllib3.util import parse_url from ._models import DEFAULT, DefaultType, NodeConfig from ._utils import fixup_module_metadata from ._version import __version__ __all__ = [ "CloudId", "DEFAULT", "DefaultType", "basic_auth_to_header", "client_meta_version", "create_user_agent", "dataclasses", "parse_cloud_id", "percent_encode", "resolve_default", "to_bytes", "to_str", "url_to_node_config", ] T = TypeVar("T") def resolve_default(val: Union[DefaultType, T], default: T) -> T: """Resolves a value that could be the ``DEFAULT`` sentinel into either the given value or the default value. """ return val if val is not DEFAULT else default def create_user_agent(name: str, version: str) -> str: """Creates the 'User-Agent' header given the library name and version""" return ( f"{name}/{version} (Python/{python_version()}; elastic-transport/{__version__})" ) def client_meta_version(version: str) -> str: """Converts a Python version into a version string compatible with the ``X-Elastic-Client-Meta`` HTTP header. """ match = re.match(r"^([0-9][0-9.]*[0-9]|[0-9])(.*)$", version) if match is None: raise ValueError( "Version {version!r} not formatted like a Python version string" ) version, version_suffix = match.groups() # Don't treat post-releases as pre-releases. if re.search(r"^\.post[0-9]*$", version_suffix): return version if version_suffix: version += "p" return version @dataclasses.dataclass(frozen=True, repr=True) class CloudId: #: Name of the cluster in Elastic Cloud cluster_name: str #: Host and port of the Elasticsearch instance es_address: Optional[Tuple[str, int]] #: Host and port of the Kibana instance kibana_address: Optional[Tuple[str, int]] def parse_cloud_id(cloud_id: str) -> CloudId: """Parses an Elastic Cloud ID into its components""" try: cloud_id = to_str(cloud_id) cluster_name, _, cloud_id = cloud_id.partition(":") parts = to_str(binascii.a2b_base64(to_bytes(cloud_id, "ascii")), "ascii").split( "$" ) parent_dn = parts[0] if not parent_dn: raise ValueError() # Caught and re-raised properly below es_uuid: Optional[str] kibana_uuid: Optional[str] try: es_uuid = parts[1] except IndexError: es_uuid = None try: kibana_uuid = parts[2] or None except IndexError: kibana_uuid = None if ":" in parent_dn: parent_dn, _, parent_port = parent_dn.rpartition(":") port = int(parent_port) else: port = 443 except (ValueError, IndexError, UnicodeError): raise ValueError("Cloud ID is not properly formatted") from None es_host = f"{es_uuid}.{parent_dn}" if es_uuid else None kibana_host = f"{kibana_uuid}.{parent_dn}" if kibana_uuid else None return CloudId( cluster_name=cluster_name, es_address=(es_host, port) if es_host else None, kibana_address=(kibana_host, port) if kibana_host else None, ) def to_str( value: Union[str, bytes], encoding: str = "utf-8", errors: str = "strict" ) -> str: if isinstance(value, bytes): return value.decode(encoding, errors) return value def to_bytes( value: Union[str, bytes], encoding: str = "utf-8", errors: str = "strict" ) -> bytes: if isinstance(value, str): return value.encode(encoding, errors) return value # Python 3.7 added '~' to the safe list for urllib.parse.quote() _QUOTE_ALWAYS_SAFE = frozenset( "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_.-~" ) def percent_encode( string: Union[bytes, str], safe: str = "/", encoding: Optional[str] = None, errors: Optional[str] = None, ) -> str: """Percent-encodes a string so it can be used in an HTTP request target""" # Redefines 'urllib.parse.quote()' to always have the '~' character # within the 'ALWAYS_SAFE' list. The character was added in Python 3.7 safe = "".join(_QUOTE_ALWAYS_SAFE.union(set(safe))) return _quote(string, safe, encoding=encoding, errors=errors) # type: ignore[arg-type] def basic_auth_to_header(basic_auth: Tuple[str, str]) -> str: """Converts a 2-tuple into a 'Basic' HTTP Authorization header""" if ( not isinstance(basic_auth, tuple) or len(basic_auth) != 2 or any(not isinstance(item, (str, bytes)) for item in basic_auth) ): raise ValueError( "'basic_auth' must be a 2-tuple of str/bytes (username, password)" ) return ( f"Basic {base64.b64encode(b':'.join(to_bytes(x) for x in basic_auth)).decode()}" ) def url_to_node_config( url: str, use_default_ports_for_scheme: bool = False ) -> NodeConfig: """Constructs a :class:`elastic_transport.NodeConfig` instance from a URL. If a username/password are specified in the URL they are converted to an 'Authorization' header. Always fills in a default port for HTTPS. :param url: URL to transform into a NodeConfig. :param use_default_ports_for_scheme: If 'True' will resolve default ports for HTTP. """ try: parsed_url = parse_url(url) except LocationParseError: raise ValueError(f"Could not parse URL {url!r}") from None parsed_port: Optional[int] = parsed_url.port if parsed_url.port is None and parsed_url.scheme is not None: # Always fill in a default port for HTTPS if parsed_url.scheme == "https": parsed_port = 443 # Only fill HTTP default port when asked to explicitly elif parsed_url.scheme == "http" and use_default_ports_for_scheme: parsed_port = 80 if any( component in (None, "") for component in (parsed_url.scheme, parsed_url.host, parsed_port) ): raise ValueError( "URL must include a 'scheme', 'host', and 'port' component (ie 'https://localhost:9200')" ) assert parsed_url.scheme is not None assert parsed_url.host is not None assert parsed_port is not None headers = {} if parsed_url.auth: # `urllib3.util.url_parse` ensures `parsed_url` is correctly # percent-encoded but does not percent-decode userinfo, so we have to # do it ourselves to build the basic auth header correctly. encoded_username, _, encoded_password = parsed_url.auth.partition(":") username = urllib.parse.unquote(encoded_username) password = urllib.parse.unquote(encoded_password) headers["authorization"] = basic_auth_to_header((username, password)) host = parsed_url.host.strip("[]") if not parsed_url.path or parsed_url.path == "/": path_prefix = "" else: path_prefix = parsed_url.path return NodeConfig( scheme=parsed_url.scheme, host=host, port=parsed_port, path_prefix=path_prefix, headers=headers, ) fixup_module_metadata(__name__, globals()) del fixup_module_metadata elastic-transport-python-8.17.1/elastic_transport/py.typed000066400000000000000000000000001476450415400240440ustar00rootroot00000000000000elastic-transport-python-8.17.1/noxfile.py000066400000000000000000000057251476450415400206460ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import nox SOURCE_FILES = ( "noxfile.py", "setup.py", "elastic_transport/", "utils/", "tests/", "docs/sphinx/", ) @nox.session() def format(session): session.install("black~=24.0", "isort", "pyupgrade") session.run("black", "--target-version=py37", *SOURCE_FILES) session.run("isort", *SOURCE_FILES) session.run("python", "utils/license-headers.py", "fix", *SOURCE_FILES) lint(session) @nox.session def lint(session): session.install( "flake8", "black~=24.0", "isort", "mypy==1.7.1", "types-requests", "types-certifi", ) # https://github.com/python/typeshed/issues/10786 session.run( "python", "-m", "pip", "uninstall", "--yes", "types-urllib3", silent=True ) session.install(".[develop]") session.run("black", "--check", "--target-version=py37", *SOURCE_FILES) session.run("isort", "--check", *SOURCE_FILES) session.run("flake8", "--ignore=E501,W503,E203,E704", *SOURCE_FILES) session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES) session.run("mypy", "--strict", "--show-error-codes", "elastic_transport/") @nox.session(python=["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]) def test(session): session.install(".[develop]") session.run( "pytest", "--cov=elastic_transport", *(session.posargs or ("tests/",)), env={"PYTHONWARNINGS": "always::DeprecationWarning"}, ) session.run("coverage", "report", "-m") @nox.session(name="test-min-deps", python="3.8") def test_min_deps(session): session.install("-r", "requirements-min.txt", ".[develop]", silent=False) session.run( "pytest", "--cov=elastic_transport", *(session.posargs or ("tests/",)), env={"PYTHONWARNINGS": "always::DeprecationWarning"}, ) session.run("coverage", "report", "-m") @nox.session(python="3") def docs(session): session.install(".[develop]") session.chdir("docs/sphinx") session.run( "sphinx-build", "-T", "-E", "-b", "html", "-d", "_build/doctrees", "-D", "language=en", ".", "_build/html", ) elastic-transport-python-8.17.1/requirements-min.txt000066400000000000000000000000761476450415400226670ustar00rootroot00000000000000requests==2.26.0 urllib3==1.26.2 aiohttp==3.8.0 httpx==0.27.0 elastic-transport-python-8.17.1/setup.cfg000066400000000000000000000003221476450415400204350ustar00rootroot00000000000000[isort] profile = black [tool:pytest] addopts = -vvv --cov-report=term-missing --cov=elastic_transport asyncio_default_fixture_loop_scope = "function" [coverage:report] omit = elastic_transport/_compat.py elastic-transport-python-8.17.1/setup.py000066400000000000000000000065361476450415400203430ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import os import re from setuptools import find_packages, setup base_dir = os.path.dirname(os.path.abspath(__file__)) with open(os.path.join(base_dir, "elastic_transport/_version.py")) as f: version = re.search(r"__version__\s+=\s+\"([^\"]+)\"", f.read()).group(1) with open(os.path.join(base_dir, "README.md")) as f: long_description = f.read() packages = [ package for package in find_packages(exclude=["tests"]) if package.startswith("elastic_transport") ] setup( name="elastic-transport", description="Transport classes and utilities shared among Python Elastic client libraries", long_description=long_description, long_description_content_type="text/markdown", version=version, author="Elastic Client Library Maintainers", author_email="client-libs@elastic.co", url="https://github.com/elastic/elastic-transport-python", project_urls={ "Source Code": "https://github.com/elastic/elastic-transport-python", "Issue Tracker": "https://github.com/elastic/elastic-transport-python/issues", "Documentation": "https://elastic-transport-python.readthedocs.io", }, package_data={"elastic_transport": ["py.typed"]}, packages=packages, install_requires=[ "urllib3>=1.26.2, <3", "certifi", ], python_requires=">=3.8", extras_require={ "develop": [ "pytest", "pytest-cov", "pytest-mock", "pytest-asyncio", "pytest-httpserver", "trustme", "requests", "aiohttp", "httpx", "respx", "opentelemetry-api", "opentelemetry-sdk", "orjson", # Override Read the Docs default (sphinx<2) "sphinx>2", "furo", "sphinx-autodoc-typehints", ], }, classifiers=[ "Development Status :: 5 - Production/Stable", "License :: OSI Approved :: Apache Software License", "Intended Audience :: Developers", "Operating System :: OS Independent", "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ], ) elastic-transport-python-8.17.1/tests/000077500000000000000000000000001476450415400177615ustar00rootroot00000000000000elastic-transport-python-8.17.1/tests/__init__.py000066400000000000000000000014231476450415400220720ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. elastic-transport-python-8.17.1/tests/async_/000077500000000000000000000000001476450415400212355ustar00rootroot00000000000000elastic-transport-python-8.17.1/tests/async_/__init__.py000066400000000000000000000014231476450415400233460ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. elastic-transport-python-8.17.1/tests/async_/test_async_transport.py000066400000000000000000000500271476450415400261030ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import asyncio import random import re import ssl import sys import time import warnings from unittest import mock import pytest from elastic_transport import ( AiohttpHttpNode, AsyncTransport, ConnectionError, ConnectionTimeout, HttpxAsyncHttpNode, NodeConfig, RequestsHttpNode, SniffingError, SniffOptions, TransportError, TransportWarning, Urllib3HttpNode, ) from elastic_transport._node._base import DEFAULT_USER_AGENT from elastic_transport.client_utils import DEFAULT from tests.conftest import AsyncDummyNode @pytest.mark.asyncio async def test_async_transport_httpbin(httpbin_node_config): t = AsyncTransport([httpbin_node_config], meta_header=False) resp, data = await t.perform_request("GET", "/anything?key=value") assert resp.status == 200 assert data["method"] == "GET" assert data["url"] == "https://httpbin.org/anything?key=value" assert data["args"] == {"key": "value"} data["headers"].pop("X-Amzn-Trace-Id", None) assert data["headers"] == {"User-Agent": DEFAULT_USER_AGENT, "Host": "httpbin.org"} @pytest.mark.skipif( sys.version_info < (3, 8), reason="Mock didn't support async before Python 3.8" ) @pytest.mark.asyncio async def test_transport_close_node_pool(): t = AsyncTransport([NodeConfig("http", "localhost", 443)]) with mock.patch.object(t.node_pool.all()[0], "close") as node_close: await t.close() node_close.assert_called_with() @pytest.mark.asyncio async def test_request_with_custom_user_agent_header(): t = AsyncTransport( [NodeConfig("http", "localhost", 80)], node_class=AsyncDummyNode, meta_header=False, ) await t.perform_request("GET", "/", headers={"user-agent": "my-custom-value/1.2.3"}) assert 1 == len(t.node_pool.get().calls) assert { "body": None, "request_timeout": DEFAULT, "headers": {"user-agent": "my-custom-value/1.2.3"}, } == t.node_pool.get().calls[0][1] @pytest.mark.asyncio async def test_body_gets_encoded_into_bytes(): t = AsyncTransport([NodeConfig("http", "localhost", 80)], node_class=AsyncDummyNode) await t.perform_request( "GET", "/", headers={"Content-type": "application/json"}, body={"key": "你好"} ) calls = t.node_pool.get().calls assert 1 == len(calls) args, kwargs = calls[0] assert ("GET", "/") == args assert kwargs["body"] == b'{"key":"\xe4\xbd\xa0\xe5\xa5\xbd"}' @pytest.mark.asyncio async def test_body_bytes_get_passed_untouched(): t = AsyncTransport([NodeConfig("http", "localhost", 80)], node_class=AsyncDummyNode) body = b"\xe4\xbd\xa0\xe5\xa5\xbd" await t.perform_request( "GET", "/", body=body, headers={"Content-Type": "application/json"} ) calls = t.node_pool.get().calls assert 1 == len(calls) args, kwargs = calls[0] assert ("GET", "/") == args assert kwargs["body"] == b"\xe4\xbd\xa0\xe5\xa5\xbd" def test_kwargs_passed_on_to_node_pool(): dt = object() t = AsyncTransport( [NodeConfig("http", "localhost", 80)], dead_node_backoff_factor=dt, max_dead_node_backoff=dt, ) assert dt is t.node_pool.dead_node_backoff_factor assert dt is t.node_pool.max_dead_node_backoff @pytest.mark.asyncio async def test_request_will_fail_after_x_retries(): t = AsyncTransport( [ NodeConfig( "http", "localhost", 80, _extras={"exception": ConnectionError("abandon ship")}, ) ], node_class=AsyncDummyNode, ) with pytest.raises(ConnectionError) as e: await t.perform_request("GET", "/") assert 4 == len(t.node_pool.get().calls) assert len(e.value.errors) == 3 assert all(isinstance(error, ConnectionError) for error in e.value.errors) @pytest.mark.parametrize("retry_on_timeout", [True, False]) @pytest.mark.asyncio async def test_retry_on_timeout(retry_on_timeout): t = AsyncTransport( [ NodeConfig( "http", "localhost", 80, _extras={"exception": ConnectionTimeout("abandon ship")}, ), NodeConfig( "http", "localhost", 81, _extras={"exception": ConnectionError("error!")}, ), ], node_class=AsyncDummyNode, max_retries=1, retry_on_timeout=retry_on_timeout, randomize_nodes_in_pool=False, ) if retry_on_timeout: with pytest.raises(ConnectionError) as e: await t.perform_request("GET", "/") assert len(e.value.errors) == 1 assert isinstance(e.value.errors[0], ConnectionTimeout) else: with pytest.raises(ConnectionTimeout) as e: await t.perform_request("GET", "/") assert len(e.value.errors) == 0 @pytest.mark.asyncio async def test_retry_on_status(): t = AsyncTransport( [ NodeConfig("http", "localhost", 80, _extras={"status": 404}), NodeConfig( "http", "localhost", 81, _extras={"status": 401}, ), NodeConfig( "http", "localhost", 82, _extras={"status": 403}, ), NodeConfig( "http", "localhost", 83, _extras={"status": 555}, ), ], node_class=AsyncDummyNode, node_selector_class="round_robin", retry_on_status=(401, 403, 404), randomize_nodes_in_pool=False, max_retries=5, ) meta, _ = await t.perform_request("GET", "/") assert meta.status == 555 # Assert that every node is called once node_calls = [len(node.calls) for node in t.node_pool.all()] assert node_calls == [ 1, 1, 1, 1, ] @pytest.mark.asyncio async def test_failed_connection_will_be_marked_as_dead(): t = AsyncTransport( [ NodeConfig( "http", "localhost", 80, _extras={"exception": ConnectionError("abandon ship")}, ), NodeConfig( "http", "localhost", 81, _extras={"exception": ConnectionError("abandon ship")}, ), ], max_retries=3, node_class=AsyncDummyNode, ) with pytest.raises(ConnectionError) as e: await t.perform_request("GET", "/") assert 0 == len(t.node_pool._alive_nodes) assert 2 == len(t.node_pool._dead_nodes.queue) assert len(e.value.errors) == 3 assert all(isinstance(error, ConnectionError) for error in e.value.errors) @pytest.mark.asyncio async def test_resurrected_connection_will_be_marked_as_live_on_success(): for method in ("GET", "HEAD"): t = AsyncTransport( [ NodeConfig("http", "localhost", 80), NodeConfig("http", "localhost", 81), ], node_class=AsyncDummyNode, ) node1 = t.node_pool.get() node2 = t.node_pool.get() t.node_pool.mark_dead(node1) t.node_pool.mark_dead(node2) await t.perform_request(method, "/") assert 1 == len(t.node_pool._alive_nodes) assert 1 == len(t.node_pool._dead_consecutive_failures) assert 1 == len(t.node_pool._dead_nodes.queue) @pytest.mark.asyncio async def test_mark_dead_error_doesnt_raise(): t = AsyncTransport( [ NodeConfig("http", "localhost", 80, _extras={"status": 502}), NodeConfig("http", "localhost", 81), ], retry_on_status=(502,), node_class=AsyncDummyNode, randomize_nodes_in_pool=False, ) bad_node = t.node_pool._all_nodes[NodeConfig("http", "localhost", 80)] with mock.patch.object(t.node_pool, "mark_dead") as mark_dead, mock.patch.object( t, "sniff" ) as sniff: sniff.side_effect = TransportError("sniffing error!") await t.perform_request("GET", "/") mark_dead.assert_called_with(bad_node) @pytest.mark.asyncio async def test_node_class_as_string(): t = AsyncTransport([NodeConfig("http", "localhost", 80)], node_class="aiohttp") assert isinstance(t.node_pool.get(), AiohttpHttpNode) t = AsyncTransport([NodeConfig("http", "localhost", 80)], node_class="httpxasync") assert isinstance(t.node_pool.get(), HttpxAsyncHttpNode) with pytest.raises(ValueError) as e: AsyncTransport([NodeConfig("http", "localhost", 80)], node_class="huh?") assert str(e.value) == ( "Unknown option for node_class: 'huh?'. " "Available options are: 'aiohttp', 'httpxasync', 'requests', 'urllib3'" ) @pytest.mark.parametrize(["status", "boolean"], [(200, True), (299, True)]) @pytest.mark.asyncio async def test_head_response_true(status, boolean): t = AsyncTransport( [NodeConfig("http", "localhost", 80, _extras={"status": status, "body": b""})], node_class=AsyncDummyNode, ) resp, data = await t.perform_request("HEAD", "/") assert resp.status == status assert data is None @pytest.mark.asyncio async def test_head_response_false(): t = AsyncTransport( [NodeConfig("http", "localhost", 80, _extras={"status": 404, "body": b""})], node_class=AsyncDummyNode, ) meta, resp = await t.perform_request("HEAD", "/") assert meta.status == 404 assert resp is None # 404s don't count as a dead node status. assert 0 == len(t.node_pool._dead_nodes.queue) @pytest.mark.parametrize( "node_class, client_short_name", [ ("aiohttp", "ai"), (AiohttpHttpNode, "ai"), ("httpxasync", "hx"), (HttpxAsyncHttpNode, "hx"), ], ) @pytest.mark.asyncio async def test_transport_client_meta_node_class(node_class, client_short_name): t = AsyncTransport([NodeConfig("http", "localhost", 80)], node_class=node_class) assert ( t._transport_client_meta[3] == t.node_pool.node_class._CLIENT_META_HTTP_CLIENT ) assert t._transport_client_meta[3][0] == client_short_name assert re.match( rf"^et=[0-9.]+p?,py=[0-9.]+p?,t=[0-9.]+p?,{client_short_name}=[0-9.]+p?$", ",".join(f"{k}={v}" for k, v in t._transport_client_meta), ) @pytest.mark.asyncio async def test_transport_default_client_meta_node_class(): # Defaults to aiohttp t = AsyncTransport( [NodeConfig("http", "localhost", 80)], client_meta_service=("es", "8.0.0p") ) assert t._transport_client_meta[3][0] == "ai" assert [x[0] for x in t._transport_client_meta[:3]] == ["es", "py", "t"] @pytest.mark.parametrize( "node_class", ["urllib3", "requests", Urllib3HttpNode, RequestsHttpNode], ) def test_transport_and_node_are_async(node_class): with pytest.raises(ValueError) as e: AsyncTransport([NodeConfig("http", "localhost", 80)], node_class=node_class) assert ( str(e.value) == "Specified 'node_class' is not async, should be async instead" ) @pytest.mark.asyncio async def test_sniff_on_start(): calls = [] def sniff_callback(*args): nonlocal calls calls.append(args) return [NodeConfig("http", "localhost", 80)] t = AsyncTransport( [NodeConfig("http", "localhost", 80)], node_class=AsyncDummyNode, sniff_on_start=True, sniff_callback=sniff_callback, ) assert len(calls) == 0 await t._async_call() assert len(calls) == 1 await t.perform_request("GET", "/") assert len(calls) == 1 transport, sniff_options = calls[0] assert transport is t assert sniff_options == SniffOptions(is_initial_sniff=True, sniff_timeout=0.5) @pytest.mark.asyncio async def test_sniff_before_requests(): calls = [] def sniff_callback(*args): nonlocal calls calls.append(args) return [] t = AsyncTransport( [NodeConfig("http", "localhost", 80)], node_class=AsyncDummyNode, sniff_before_requests=True, sniff_callback=sniff_callback, ) assert len(calls) == 0 await t.perform_request("GET", "/") await t._sniffing_task assert len(calls) == 1 transport, sniff_options = calls[0] assert transport is t assert sniff_options == SniffOptions(is_initial_sniff=False, sniff_timeout=0.5) @pytest.mark.asyncio async def test_sniff_on_node_failure(): calls = [] def sniff_callback(*args): nonlocal calls calls.append(args) return [] t = AsyncTransport( [ NodeConfig("http", "localhost", 80), NodeConfig("http", "localhost", 81, _extras={"status": 500}), ], randomize_nodes_in_pool=False, node_selector_class="round_robin", node_class=AsyncDummyNode, max_retries=1, sniff_on_node_failure=True, sniff_callback=sniff_callback, ) assert t._sniffing_task is None assert len(calls) == 0 await t.perform_request("GET", "/") # 200 assert t._sniffing_task is None assert len(calls) == 0 await t.perform_request("GET", "/") # 500 await t._sniffing_task assert len(calls) == 1 transport, sniff_options = calls[0] assert transport is t assert sniff_options == SniffOptions(is_initial_sniff=False, sniff_timeout=0.5) @pytest.mark.parametrize( "kwargs", [ {"sniff_on_start": True}, {"sniff_on_node_failure": True}, {"sniff_before_requests": True}, ], ) @pytest.mark.asyncio async def test_error_with_sniffing_enabled_without_callback(kwargs): with pytest.raises(ValueError) as e: AsyncTransport([NodeConfig("http", "localhost", 80)], **kwargs) assert str(e.value) == "Enabling sniffing requires specifying a 'sniff_callback'" @pytest.mark.asyncio async def test_error_sniffing_callback_without_sniffing_enabled(): with pytest.raises(ValueError) as e: AsyncTransport( [NodeConfig("http", "localhost", 80)], sniff_callback=lambda *_: [] ) assert str(e.value) == ( "Using 'sniff_callback' requires enabling sniffing via 'sniff_on_start', " "'sniff_before_requests' or 'sniff_on_node_failure'" ) @pytest.mark.asyncio async def test_heterogeneous_node_config_warning_with_sniffing(): with warnings.catch_warnings(record=True) as w: # SSLContext objects cannot be compared and are thus ignored context = ssl.create_default_context() AsyncTransport( [ NodeConfig( "https", "localhost", 80, path_prefix="/a", ssl_context=context ), NodeConfig( "https", "localhost", 81, path_prefix="/b", ssl_context=context ), ], sniff_on_start=True, sniff_callback=lambda *_: [ NodeConfig("https", "localhost", 80, path_prefix="/a") ], ) assert len(w) == 1 assert w[0].category == TransportWarning assert str(w[0].message) == ( "Detected NodeConfig instances with different options. It's " "recommended to keep all options except for 'host' and 'port' " "the same for sniffing to work reliably." ) @pytest.mark.parametrize("async_sniff_callback", [True, False]) @pytest.mark.asyncio async def test_sniffed_nodes_added_to_pool(async_sniff_callback): sniffed_nodes = [ NodeConfig("http", "localhost", 80), NodeConfig("http", "localhost", 81), ] loop = asyncio.get_running_loop() sniffed_at = 0.0 # Test that we accept both sync and async sniff_callbacks if async_sniff_callback: async def sniff_callback(*_): nonlocal loop, sniffed_at await asyncio.sleep(0.1) sniffed_at = loop.time() return sniffed_nodes else: def sniff_callback(*_): nonlocal loop, sniffed_at time.sleep(0.1) sniffed_at = loop.time() return sniffed_nodes t = AsyncTransport( [ NodeConfig("http", "localhost", 80), ], node_class=AsyncDummyNode, sniff_before_requests=True, sniff_callback=sniff_callback, ) assert len(t.node_pool) == 1 request_at = loop.time() await t.perform_request("GET", "/") response_at = loop.time() await t._sniffing_task assert 0.1 <= (sniffed_at - request_at) <= 0.15 assert 0 <= response_at - request_at < 0.05 # The node pool knows when nodes are already in the pool # so we shouldn't get duplicates after sniffing. assert len(t.node_pool.all()) == 2 assert len(t.node_pool) == 2 assert set(sniffed_nodes) == {node.config for node in t.node_pool.all()} @pytest.mark.asyncio async def test_sniff_error_resets_lock_and_last_sniffed_at(): def sniff_error(*_): raise TransportError("This is an error!") t = AsyncTransport( [ NodeConfig("http", "localhost", 80), ], node_class=AsyncDummyNode, sniff_on_start=True, sniff_callback=sniff_error, ) last_sniffed_at = t._last_sniffed_at with pytest.raises(TransportError) as e: await t.perform_request("GET", "/") assert str(e.value) == "This is an error!" assert t._last_sniffed_at == last_sniffed_at assert t._sniffing_task.done() async def _empty_sniff(*_): # Used in the below test to mock an empty sniff attempt await asyncio.sleep(0) return [] @pytest.mark.parametrize("sniff_callback", [lambda *_: [], _empty_sniff]) @pytest.mark.asyncio async def test_sniff_on_start_no_results_errors(sniff_callback): t = AsyncTransport( [ NodeConfig("http", "localhost", 80), ], node_class=AsyncDummyNode, sniff_on_start=True, sniff_callback=sniff_callback, ) with pytest.raises(SniffingError) as e: await t._async_call() assert ( str(e.value) == "No viable nodes were discovered on the initial sniff attempt" ) @pytest.mark.parametrize("pool_size", [1, 8]) @pytest.mark.asyncio async def test_multiple_tasks_test(pool_size): node_configs = [ NodeConfig("http", "localhost", 80), NodeConfig("http", "localhost", 81), NodeConfig("http", "localhost", 82), NodeConfig("http", "localhost", 83, _extras={"status": 500}), ] async def sniff_callback(*_): await asyncio.sleep(random.random()) return node_configs t = AsyncTransport( node_configs, retry_on_status=[500], max_retries=5, node_class=AsyncDummyNode, sniff_on_start=True, sniff_before_requests=True, sniff_on_node_failure=True, sniff_callback=sniff_callback, ) loop = asyncio.get_running_loop() start = loop.time() async def run_requests(): successful_requests = 0 while loop.time() - start < 2: await t.perform_request("GET", "/") successful_requests += 1 return successful_requests tasks = [loop.create_task(run_requests()) for _ in range(pool_size * 2)] assert sum([await task for task in tasks]) >= 1000 @pytest.mark.asyncio async def test_httpbin(httpbin_node_config): t = AsyncTransport([httpbin_node_config]) resp = await t.perform_request("GET", "/anything") assert resp.meta.status == 200 assert isinstance(resp.body, dict) elastic-transport-python-8.17.1/tests/async_/test_httpbin.py000066400000000000000000000100051476450415400243120ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import dataclasses import json import pytest from elastic_transport import AiohttpHttpNode, AsyncTransport from elastic_transport._node._base import DEFAULT_USER_AGENT from ..test_httpbin import parse_httpbin @pytest.mark.asyncio async def test_simple_request(httpbin_node_config): t = AsyncTransport([httpbin_node_config]) resp, data = await t.perform_request( "GET", "/anything?key[]=1&key[]=2&q1&q2=", headers={"Custom": "headeR", "content-type": "application/json"}, body={"JSON": "body"}, ) assert resp.status == 200 assert data["method"] == "GET" assert data["url"] == "https://httpbin.org/anything?key[]=1&key[]=2&q1&q2=" # httpbin makes no-value query params into '' assert data["args"] == { "key[]": ["1", "2"], "q1": "", "q2": "", } assert data["data"] == '{"JSON":"body"}' assert data["json"] == {"JSON": "body"} request_headers = { "Content-Type": "application/json", "Content-Length": "15", "Custom": "headeR", "Host": "httpbin.org", } assert all(v == data["headers"][k] for k, v in request_headers.items()) @pytest.mark.asyncio async def test_node(httpbin_node_config): def new_node(**kwargs): return AiohttpHttpNode(dataclasses.replace(httpbin_node_config, **kwargs)) node = new_node() resp, data = await node.perform_request("GET", "/anything") assert resp.status == 200 parsed = parse_httpbin(data) assert parsed == { "headers": { "Host": "httpbin.org", "User-Agent": DEFAULT_USER_AGENT, }, "method": "GET", "url": "https://httpbin.org/anything", } node = new_node(http_compress=True) resp, data = await node.perform_request("GET", "/anything") assert resp.status == 200 parsed = parse_httpbin(data) assert parsed == { "headers": { "Accept-Encoding": "gzip", "Host": "httpbin.org", "User-Agent": DEFAULT_USER_AGENT, }, "method": "GET", "url": "https://httpbin.org/anything", } resp, data = await node.perform_request("GET", "/anything", body=b"hello, world!") assert resp.status == 200 parsed = parse_httpbin(data) assert parsed == { "headers": { "Accept-Encoding": "gzip", "Content-Encoding": "gzip", "Content-Type": "application/octet-stream", "Content-Length": "33", "Host": "httpbin.org", "User-Agent": DEFAULT_USER_AGENT, }, "method": "GET", "url": "https://httpbin.org/anything", } resp, data = await node.perform_request( "POST", "/anything", body=json.dumps({"key": "value"}).encode("utf-8"), headers={"content-type": "application/json"}, ) assert resp.status == 200 parsed = parse_httpbin(data) assert parsed == { "headers": { "Accept-Encoding": "gzip", "Content-Encoding": "gzip", "Content-Length": "36", "Content-Type": "application/json", "Host": "httpbin.org", "User-Agent": DEFAULT_USER_AGENT, }, "method": "POST", "url": "https://httpbin.org/anything", } elastic-transport-python-8.17.1/tests/async_/test_httpserver.py000066400000000000000000000023021476450415400250510ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import warnings import pytest from elastic_transport import AsyncTransport @pytest.mark.asyncio async def test_simple_request(https_server_ip_node_config): with warnings.catch_warnings(): warnings.simplefilter("error") t = AsyncTransport([https_server_ip_node_config]) resp, data = await t.perform_request("GET", "/foobar") assert resp.status == 200 assert data == {"foo": "bar"} elastic-transport-python-8.17.1/tests/conftest.py000066400000000000000000000105341476450415400221630ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import hashlib import logging import socket import ssl import pytest import trustme from pytest_httpserver import HTTPServer from elastic_transport import ApiResponseMeta, BaseNode, HttpHeaders, NodeConfig from elastic_transport._node import NodeApiResponse class DummyNode(BaseNode): def __init__(self, config: NodeConfig): super().__init__(config) self.exception = config._extras.pop("exception", None) self.status = config._extras.pop("status", 200) self.body = config._extras.pop("body", b"{}") self.calls = [] self._headers = config._extras.pop("headers", {}) def perform_request(self, *args, **kwargs): self.calls.append((args, kwargs)) if self.exception: raise self.exception meta = ApiResponseMeta( node=self.config, duration=0.0, http_version="1.1", status=self.status, headers=HttpHeaders(self._headers), ) return NodeApiResponse(meta, self.body) class AsyncDummyNode(DummyNode): async def perform_request(self, *args, **kwargs): self.calls.append((args, kwargs)) if self.exception: raise self.exception meta = ApiResponseMeta( node=self.config, duration=0.0, http_version="1.1", status=self.status, headers=HttpHeaders(self._headers), ) return NodeApiResponse(meta, self.body) @pytest.fixture(scope="session", params=[True, False]) def httpbin_cert_fingerprint(request) -> str: """Gets the SHA256 fingerprint of the certificate for 'httpbin.org'""" sock = socket.create_connection(("httpbin.org", 443)) ctx = ssl.create_default_context() ctx.check_hostname = False ctx.verify_mode = ssl.CERT_NONE sock = ctx.wrap_socket(sock) digest = hashlib.sha256(sock.getpeercert(binary_form=True)).hexdigest() assert len(digest) == 64 sock.close() if request.param: return digest else: return ":".join([digest[i : i + 2] for i in range(0, len(digest), 2)]) @pytest.fixture(scope="session") def httpbin_node_config() -> NodeConfig: try: sock = socket.create_connection(("httpbin.org", 443)) except Exception as e: pytest.skip(f"Couldn't connect to httpbin.org, internet not connected? {e}") sock.close() return NodeConfig( "https", "httpbin.org", 443, verify_certs=False, ssl_show_warn=False ) @pytest.fixture(scope="function", autouse=True) def elastic_transport_logging(): for name in ("node", "node_pool", "transport"): logger = logging.getLogger(f"elastic_transport.{name}") for handler in logger.handlers[:]: logger.removeHandler(handler) @pytest.fixture(scope="session") def https_server_ip_node_config(tmp_path_factory: pytest.TempPathFactory) -> NodeConfig: ca = trustme.CA() tmpdir = tmp_path_factory.mktemp("certs") ca_cert_path = str(tmpdir / "ca.pem") ca.cert_pem.write_to_path(ca_cert_path) localhost_cert = ca.issue_cert("127.0.0.1") context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) crt = localhost_cert.cert_chain_pems[0] key = localhost_cert.private_key_pem with crt.tempfile() as crt_file, key.tempfile() as key_file: context.load_cert_chain(crt_file, key_file) server = HTTPServer(ssl_context=context) server.expect_request("/foobar").respond_with_json({"foo": "bar"}) server.start() yield NodeConfig("https", "127.0.0.1", server.port, ca_certs=ca_cert_path) server.clear() if server.is_running(): server.stop() elastic-transport-python-8.17.1/tests/node/000077500000000000000000000000001476450415400207065ustar00rootroot00000000000000elastic-transport-python-8.17.1/tests/node/__init__.py000066400000000000000000000014231476450415400230170ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. elastic-transport-python-8.17.1/tests/node/test_base.py000066400000000000000000000031031476450415400232260ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import pytest from elastic_transport import ( AiohttpHttpNode, HttpxAsyncHttpNode, NodeConfig, RequestsHttpNode, Urllib3HttpNode, ) from elastic_transport._node._base import ssl_context_from_node_config @pytest.mark.parametrize( "node_cls", [Urllib3HttpNode, RequestsHttpNode, AiohttpHttpNode, HttpxAsyncHttpNode] ) def test_unknown_parameter(node_cls): with pytest.raises(TypeError): node_cls(unknown_option=1) @pytest.mark.parametrize( "host, check_hostname", [ ("127.0.0.1", False), ("::1", False), ("localhost", True), ], ) def test_ssl_context_from_node_config(host, check_hostname): node_config = NodeConfig("https", host, 443) ctx = ssl_context_from_node_config(node_config) assert ctx.check_hostname == check_hostname elastic-transport-python-8.17.1/tests/node/test_http_aiohttp.py000066400000000000000000000305751476450415400250400ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import gzip import json import warnings import aiohttp import pytest from multidict import CIMultiDict from elastic_transport import AiohttpHttpNode, NodeConfig from elastic_transport._node._base import DEFAULT_USER_AGENT class TestAiohttpHttpNode: @pytest.mark.asyncio async def _get_mock_node(self, node_config, response_body=b"{}"): node = AiohttpHttpNode(node_config) node._create_aiohttp_session() def _dummy_request(*args, **kwargs): class DummyResponse: async def __aenter__(self, *_, **__): return self async def __aexit__(self, *_, **__): pass async def read(self): return response_body if args[0] != "HEAD" else b"" async def release(self): return None dummy_response = DummyResponse() dummy_response.headers = CIMultiDict() dummy_response.status = 200 _dummy_request.call_args = (args, kwargs) return dummy_response node.session.request = _dummy_request return node @pytest.mark.asyncio async def test_aiohttp_options(self): node = await self._get_mock_node( NodeConfig(scheme="http", host="localhost", port=80) ) await node.perform_request( "GET", "/index", body=b"hello, world!", headers={"key": "value"}, ) args, kwargs = node.session.request.call_args assert args == ("GET", "http://localhost:80/index") assert kwargs == { "data": b"hello, world!", "headers": { "connection": "keep-alive", "key": "value", "user-agent": DEFAULT_USER_AGENT, }, "timeout": aiohttp.ClientTimeout( total=10, connect=None, sock_read=None, sock_connect=None, ), } @pytest.mark.asyncio async def test_aiohttp_options_fingerprint(self): node = await self._get_mock_node( NodeConfig( scheme="https", host="localhost", port=443, ssl_assert_fingerprint=("00:" * 32).strip(":"), ) ) await node.perform_request( "GET", "/", ) args, kwargs = node.session.request.call_args assert args == ("GET", "https://localhost:443/") # aiohttp.Fingerprint() doesn't define equality fingerprint: aiohttp.Fingerprint = kwargs.pop("ssl") assert fingerprint.fingerprint == b"\x00" * 32 assert kwargs == { "data": None, "headers": {"connection": "keep-alive", "user-agent": DEFAULT_USER_AGENT}, "timeout": aiohttp.ClientTimeout( total=10, connect=None, sock_read=None, sock_connect=None, ), } @pytest.mark.parametrize( "options", [(5, 5, 5), (None, 5, 5), (5, None, 0), (None, None, 0), (5, 5), (None, 0)], ) @pytest.mark.asyncio async def test_aiohttp_options_timeout(self, options): if len(options) == 3: constructor_timeout, request_timeout, aiohttp_timeout = options node = await self._get_mock_node( NodeConfig( scheme="http", host="localhost", port=80, request_timeout=constructor_timeout, ) ) else: request_timeout, aiohttp_timeout = options node = await self._get_mock_node( NodeConfig(scheme="http", host="localhost", port=80) ) await node.perform_request( "GET", "/", request_timeout=request_timeout, ) args, kwargs = node.session.request.call_args assert args == ("GET", "http://localhost:80/") assert kwargs == { "data": None, "headers": {"connection": "keep-alive", "user-agent": DEFAULT_USER_AGENT}, "timeout": aiohttp.ClientTimeout( total=aiohttp_timeout, connect=None, sock_read=None, sock_connect=None, ), } @pytest.mark.asyncio async def test_http_compression(self): node = await self._get_mock_node( NodeConfig(scheme="http", host="localhost", port=80, http_compress=True) ) # 'content-encoding' shouldn't be set at a session level. # Should be applied only if the request is sent with a body. assert "content-encoding" not in node.session.headers await node.perform_request("GET", "/", body=b"{}") args, kwargs = node.session.request.call_args assert kwargs["headers"] == { "accept-encoding": "gzip", "connection": "keep-alive", "content-encoding": "gzip", "user-agent": DEFAULT_USER_AGENT, } assert gzip.decompress(kwargs["data"]) == b"{}" @pytest.mark.parametrize("http_compress", [None, False]) @pytest.mark.asyncio async def test_no_http_compression(self, http_compress): node = await self._get_mock_node( NodeConfig( scheme="http", host="localhost", port=80, http_compress=http_compress ) ) assert "content-encoding" not in node.session.headers await node.perform_request("GET", "/", body=b"{}") args, kwargs = node.session.request.call_args assert kwargs["headers"] == { "connection": "keep-alive", "user-agent": DEFAULT_USER_AGENT, } assert kwargs["data"] == b"{}" @pytest.mark.parametrize("path_prefix", ["url", "/url"]) @pytest.mark.asyncio async def test_uses_https_if_verify_certs_is_off(self, path_prefix): with warnings.catch_warnings(record=True) as w: await self._get_mock_node( NodeConfig( scheme="https", host="localhost", port=443, path_prefix=path_prefix, verify_certs=False, ) ) assert 1 == len(w) assert ( "Connecting to 'https://localhost:443/url' using TLS with verify_certs=False is insecure" == str(w[0].message) ) @pytest.mark.asyncio async def test_uses_https_if_verify_certs_is_off_no_show_warning(self): with warnings.catch_warnings(record=True) as w: node = await self._get_mock_node( NodeConfig( scheme="https", host="localhost", port=443, path_prefix="url", ssl_show_warn=False, ) ) await node.perform_request("GET", "/") assert w == [] @pytest.mark.asyncio async def test_merge_headers(self): node = await self._get_mock_node( NodeConfig( scheme="https", host="localhost", port=443, headers={"h1": "v1", "h2": "v2"}, ) ) resp, _ = await node.perform_request( "GET", "/", headers={"H2": "v2p", "H3": "v3"} ) args, kwargs = node.session.request.call_args assert args == ("GET", "https://localhost:443/") assert kwargs["headers"] == { "connection": "keep-alive", "h1": "v1", "h2": "v2p", "h3": "v3", "user-agent": DEFAULT_USER_AGENT, } @pytest.mark.parametrize("aiohttp_fixed_head_bug", [True, False]) @pytest.mark.asyncio async def test_head_workaround(self, aiohttp_fixed_head_bug): from elastic_transport._node import _http_aiohttp prev = _http_aiohttp._AIOHTTP_FIXED_HEAD_BUG try: _http_aiohttp._AIOHTTP_FIXED_HEAD_BUG = aiohttp_fixed_head_bug node = await self._get_mock_node( NodeConfig( scheme="https", host="localhost", port=443, ) ) resp, data = await node.perform_request("HEAD", "/anything") method, url = node.session.request.call_args[0] assert method == "HEAD" if aiohttp_fixed_head_bug else "GET" assert url == "https://localhost:443/anything" assert resp.status == 200 assert data == b"" finally: _http_aiohttp._AIOHTTP_FIXED_HEAD_BUG = prev @pytest.mark.asyncio async def test_ssl_assert_fingerprint(httpbin_cert_fingerprint): with warnings.catch_warnings(record=True) as w: node = AiohttpHttpNode( NodeConfig( scheme="https", host="httpbin.org", port=443, ssl_assert_fingerprint=httpbin_cert_fingerprint, ) ) resp, _ = await node.perform_request("GET", "/") assert resp.status == 200 assert [str(x.message) for x in w if x.category != DeprecationWarning] == [] @pytest.mark.asyncio async def test_default_headers(): node = AiohttpHttpNode(NodeConfig(scheme="https", host="httpbin.org", port=443)) resp, data = await node.perform_request("GET", "/anything") assert resp.status == 200 headers = json.loads(data)["headers"] headers.pop("X-Amzn-Trace-Id", None) assert headers == {"Host": "httpbin.org", "User-Agent": DEFAULT_USER_AGENT} @pytest.mark.asyncio async def test_custom_headers(): node = AiohttpHttpNode( NodeConfig( scheme="https", host="httpbin.org", port=443, headers={"accept-encoding": "gzip", "Content-Type": "application/json"}, ) ) resp, data = await node.perform_request( "GET", "/anything", headers={ "conTent-type": "application/x-ndjson", "user-agent": "custom-agent/1.2.3", }, ) assert resp.status == 200 headers = json.loads(data)["headers"] headers.pop("X-Amzn-Trace-Id", None) assert headers == { "Accept-Encoding": "gzip", "Content-Type": "application/x-ndjson", "Host": "httpbin.org", "User-Agent": "custom-agent/1.2.3", } @pytest.mark.asyncio async def test_custom_user_agent(): node = AiohttpHttpNode( NodeConfig( scheme="https", host="httpbin.org", port=443, headers={ "accept-encoding": "gzip", "Content-Type": "application/json", "user-agent": "custom-agent/1.2.3", }, ) ) resp, data = await node.perform_request( "GET", "/anything", ) assert resp.status == 200 headers = json.loads(data)["headers"] headers.pop("X-Amzn-Trace-Id", None) assert headers == { "Accept-Encoding": "gzip", "Content-Type": "application/json", "Host": "httpbin.org", "User-Agent": "custom-agent/1.2.3", } def test_repr(): node = AiohttpHttpNode(NodeConfig(scheme="https", host="localhost", port=443)) assert "" == repr(node) @pytest.mark.asyncio async def test_head(): node = AiohttpHttpNode( NodeConfig(scheme="https", host="httpbin.org", port=443, http_compress=True) ) resp, data = await node.perform_request("HEAD", "/anything") assert resp.status == 200 assert resp.headers["content-type"] == "application/json" assert data == b"" elastic-transport-python-8.17.1/tests/node/test_http_httpx.py000066400000000000000000000130301476450415400245220ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import gzip import ssl import warnings import pytest import respx from elastic_transport import HttpxAsyncHttpNode, NodeConfig from elastic_transport._node._base import DEFAULT_USER_AGENT def create_node(node_config: NodeConfig): return HttpxAsyncHttpNode(node_config) class TestHttpxAsyncNodeCreation: def test_ssl_context(self): ssl_context = ssl.create_default_context() with warnings.catch_warnings(record=True) as w: node = create_node( NodeConfig( scheme="https", host="localhost", port=80, ssl_context=ssl_context, ) ) assert node.client._transport._pool._ssl_context is ssl_context assert len(w) == 0 def test_uses_https_if_verify_certs_is_off(self): with warnings.catch_warnings(record=True) as w: _ = create_node(NodeConfig("https", "localhost", 443, verify_certs=False)) assert ( str(w[0].message) == "Connecting to 'https://localhost:443' using TLS with verify_certs=False is insecure" ) def test_no_warn_when_uses_https_if_verify_certs_is_off(self): with warnings.catch_warnings(record=True) as w: _ = create_node( NodeConfig( "https", "localhost", 443, verify_certs=False, ssl_show_warn=False, ) ) assert 0 == len(w) def test_ca_certs_with_verify_ssl_false_raises_error(self): with pytest.raises(ValueError) as exc: create_node( NodeConfig( "https", "localhost", 443, ca_certs="/ca/certs", verify_certs=False, ) ) assert ( str(exc.value) == "You cannot use 'ca_certs' when 'verify_certs=False'" ) @pytest.mark.asyncio class TestHttpxAsyncNode: @respx.mock async def test_simple_request(self): node = create_node(NodeConfig(scheme="http", host="localhost", port=80)) respx.get("http://localhost/index") await node.perform_request( "GET", "/index", b"hello world", headers={"key": "value"} ) request = respx.calls.last.request assert request.content == b"hello world" assert { "key": "value", "connection": "keep-alive", "user-agent": DEFAULT_USER_AGENT, }.items() <= request.headers.items() @respx.mock async def test_compression(self): node = create_node( NodeConfig(scheme="http", host="localhost", port=80, http_compress=True) ) respx.get("http://localhost/index") await node.perform_request("GET", "/index", b"hello world") request = respx.calls.last.request assert gzip.decompress(request.content) == b"hello world" assert {"content-encoding": "gzip"}.items() <= request.headers.items() @respx.mock async def test_default_timeout(self): node = create_node( NodeConfig(scheme="http", host="localhost", port=80, request_timeout=10) ) respx.get("http://localhost/index") await node.perform_request("GET", "/index", b"hello world") request = respx.calls.last.request assert request.extensions["timeout"]["connect"] == 10 @respx.mock async def test_overwritten_timeout(self): node = create_node( NodeConfig(scheme="http", host="localhost", port=80, request_timeout=10) ) respx.get("http://localhost/index") await node.perform_request("GET", "/index", b"hello world", request_timeout=15) request = respx.calls.last.request assert request.extensions["timeout"]["connect"] == 15 @respx.mock async def test_merge_headers(self): node = create_node( NodeConfig("http", "localhost", 80, headers={"h1": "v1", "h2": "v2"}) ) respx.get("http://localhost/index") await node.perform_request( "GET", "/index", b"hello world", headers={"h2": "v2p", "h3": "v3"} ) request = respx.calls.last.request assert request.headers["h1"] == "v1" assert request.headers["h2"] == "v2p" assert request.headers["h3"] == "v3" def test_ssl_assert_fingerprint(httpbin_cert_fingerprint): with pytest.raises(ValueError, match="httpx does not support certificate pinning"): HttpxAsyncHttpNode( NodeConfig( scheme="https", host="httpbin.org", port=443, ssl_assert_fingerprint=httpbin_cert_fingerprint, ) ) elastic-transport-python-8.17.1/tests/node/test_http_requests.py000066400000000000000000000214671476450415400252430ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import gzip import ssl import warnings from unittest.mock import Mock, patch import pytest import requests from requests.auth import HTTPBasicAuth from elastic_transport import NodeConfig, RequestsHttpNode from elastic_transport._node._base import DEFAULT_USER_AGENT class TestRequestsHttpNode: def _get_mock_node(self, node_config, response_body=b"{}"): node = RequestsHttpNode(node_config) def _dummy_send(*args, **kwargs): dummy_response = Mock() dummy_response.headers = {} dummy_response.status_code = 200 dummy_response.content = response_body dummy_response.request = args[0] dummy_response.cookies = {} _dummy_send.call_args = (args, kwargs) return dummy_response node.session.send = _dummy_send return node def _get_request(self, node, *args, **kwargs) -> requests.PreparedRequest: resp, data = node.perform_request(*args, **kwargs) status = resp.status assert 200 == status assert b"{}" == data timeout = kwargs.pop("request_timeout", node.config.request_timeout) args, kwargs = node.session.send.call_args assert timeout == kwargs["timeout"] assert 1 == len(args) return args[0] def test_close_session(self): node = RequestsHttpNode(NodeConfig("http", "localhost", 80)) with patch.object(node.session, "close") as pool_close: node.close() pool_close.assert_called_with() def test_ssl_context(self): ctx = ssl.create_default_context() node = RequestsHttpNode(NodeConfig("https", "localhost", 80, ssl_context=ctx)) adapter = node.session.get_adapter("https://localhost:80") assert adapter.poolmanager.connection_pool_kw["ssl_context"] is ctx def test_merge_headers(self): node = self._get_mock_node( NodeConfig("http", "localhost", 80, headers={"h1": "v1", "h2": "v2"}) ) req = self._get_request(node, "GET", "/", headers={"h2": "v2p", "h3": "v3"}) assert req.headers["h1"] == "v1" assert req.headers["h2"] == "v2p" assert req.headers["h3"] == "v3" def test_default_headers(self): node = self._get_mock_node(NodeConfig("http", "localhost", 80)) req = self._get_request(node, "GET", "/") assert req.headers == { "connection": "keep-alive", "user-agent": DEFAULT_USER_AGENT, } def test_no_http_compression(self): node = self._get_mock_node( NodeConfig("http", "localhost", 80, http_compress=False) ) assert not node.config.http_compress assert "accept-encoding" not in node.headers node.perform_request("GET", "/") (req,), _ = node.session.send.call_args assert req.body is None assert "accept-encoding" not in req.headers assert "content-encoding" not in req.headers node.perform_request("GET", "/", body=b"hello, world!") (req,), _ = node.session.send.call_args assert req.body == b"hello, world!" assert "accept-encoding" not in req.headers assert "content-encoding" not in req.headers @pytest.mark.parametrize("empty_body", [None, b""]) def test_http_compression(self, empty_body): node = self._get_mock_node( NodeConfig("http", "localhost", 80, http_compress=True) ) assert node.config.http_compress is True assert node.headers["accept-encoding"] == "gzip" # 'content-encoding' shouldn't be set at a connection level. # Should be applied only if the request is sent with a body. assert "content-encoding" not in node.headers node.perform_request("GET", "/", body=b"{}") (req,), _ = node.session.send.call_args assert gzip.decompress(req.body) == b"{}" assert req.headers["accept-encoding"] == "gzip" assert req.headers["content-encoding"] == "gzip" node.perform_request("GET", "/", body=empty_body) (req,), _ = node.session.send.call_args assert req.body is None assert req.headers["accept-encoding"] == "gzip" print(req.headers) assert "content-encoding" not in req.headers @pytest.mark.parametrize("request_timeout", [None, 15]) def test_timeout_override_default(self, request_timeout): node = self._get_mock_node( NodeConfig("http", "localhost", 80, request_timeout=request_timeout) ) assert node.config.request_timeout == request_timeout node.perform_request("GET", "/") _, kwargs = node.session.send.call_args assert kwargs["timeout"] == request_timeout node.perform_request("GET", "/", request_timeout=5) _, kwargs = node.session.send.call_args assert kwargs["timeout"] == 5 node.perform_request("GET", "/", request_timeout=None) _, kwargs = node.session.send.call_args assert kwargs["timeout"] is None def test_uses_https_if_verify_certs_is_off(self): with warnings.catch_warnings(record=True) as w: RequestsHttpNode(NodeConfig("https", "localhost", 443, verify_certs=False)) assert 1 == len(w) assert ( "Connecting to 'https://localhost:443' using TLS with verify_certs=False is insecure" == str(w[0].message) ) def test_no_warn_when_uses_https_if_verify_certs_is_off(self): with warnings.catch_warnings(record=True) as w: RequestsHttpNode( NodeConfig( "https", "localhost", 443, verify_certs=False, ssl_show_warn=False ) ) assert 0 == len(w) def test_no_warning_when_using_ssl_context(self): ctx = ssl.create_default_context() with warnings.catch_warnings(record=True) as w: RequestsHttpNode(NodeConfig("https", "localhost", 443, ssl_context=ctx)) assert 0 == len(w) def test_ca_certs_with_verify_ssl_false_raises_error(self): with pytest.raises(ValueError) as e: RequestsHttpNode( NodeConfig( "https", "localhost", 443, ca_certs="/ca/certs", verify_certs=False ) ) assert str(e.value) == "You cannot use 'ca_certs' when 'verify_certs=False'" def test_client_cert_is_used_as_session_cert(self): conn = RequestsHttpNode( NodeConfig("https", "localhost", 443, client_cert="/client/cert") ) assert conn.session.cert == "/client/cert" conn = RequestsHttpNode( NodeConfig( "https", "localhost", 443, client_cert="/client/cert", client_key="/client/key", ) ) assert conn.session.cert == ("/client/cert", "/client/key") def test_ca_certs_is_used_as_session_verify(self): conn = RequestsHttpNode( NodeConfig("https", "localhost", 443, ca_certs="/ca/certs") ) assert conn.session.verify == "/ca/certs" def test_surrogatepass_into_bytes(self): data = b"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa" node = self._get_mock_node( NodeConfig("http", "localhost", 80), response_body=data ) _, data = node.perform_request("GET", "/") assert b"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa" == data @pytest.mark.parametrize("_extras", [None, {}, {"requests.session.auth": None}]) def test_requests_no_session_auth(self, _extras): node = self._get_mock_node(NodeConfig("http", "localhost", 80, _extras=_extras)) assert node.session.auth is None def test_requests_custom_auth(self): auth = HTTPBasicAuth("username", "password") node = self._get_mock_node( NodeConfig("http", "localhost", 80, _extras={"requests.session.auth": auth}) ) assert node.session.auth is auth node.perform_request("GET", "/") (request,), _ = node.session.send.call_args assert request.headers["authorization"] == "Basic dXNlcm5hbWU6cGFzc3dvcmQ=" elastic-transport-python-8.17.1/tests/node/test_http_urllib3.py000066400000000000000000000212321476450415400247320ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import gzip import re import ssl import warnings from unittest.mock import Mock, patch import pytest import urllib3 from urllib3.response import HTTPHeaderDict from elastic_transport import NodeConfig, TransportError, Urllib3HttpNode from elastic_transport._node._base import DEFAULT_USER_AGENT class TestUrllib3HttpNode: def _get_mock_node(self, node_config, response_body=b"{}"): node = Urllib3HttpNode(node_config) def _dummy_urlopen(*args, **kwargs): dummy_response = Mock() dummy_response.headers = HTTPHeaderDict({}) dummy_response.status = 200 dummy_response.data = response_body _dummy_urlopen.call_args = (args, kwargs) return dummy_response node.pool.urlopen = _dummy_urlopen return node def test_close_pool(self): node = Urllib3HttpNode(NodeConfig("http", "localhost", 80)) with patch.object(node.pool, "close") as pool_close: node.close() pool_close.assert_called_with() def test_ssl_context(self): ctx = ssl.create_default_context() node = Urllib3HttpNode(NodeConfig("https", "localhost", 80, ssl_context=ctx)) assert len(node.pool.conn_kw.keys()) == 1 assert isinstance(node.pool.conn_kw["ssl_context"], ssl.SSLContext) assert node.scheme == "https" def test_no_http_compression(self): node = self._get_mock_node( NodeConfig("http", "localhost", 80, http_compress=False) ) assert not node.config.http_compress assert "accept-encoding" not in node.headers node.perform_request("GET", "/") (_, _), kwargs = node.pool.urlopen.call_args assert kwargs["body"] is None assert "accept-encoding" not in kwargs["headers"] assert "content-encoding" not in kwargs["headers"] node.perform_request("GET", "/", body=b"hello, world!") (_, _), kwargs = node.pool.urlopen.call_args assert kwargs["body"] == b"hello, world!" assert "accept-encoding" not in kwargs["headers"] assert "content-encoding" not in kwargs["headers"] @pytest.mark.parametrize( ["request_target", "expected_target"], [ ("/_search", "/prefix/_search"), ("/?key=val", "/prefix/?key=val"), ("/_search?key=val/", "/prefix/_search?key=val/"), ], ) def test_path_prefix_applied_to_target(self, request_target, expected_target): node = self._get_mock_node( NodeConfig("http", "localhost", 80, path_prefix="/prefix") ) node.perform_request("GET", request_target) (_, target), _ = node.pool.urlopen.call_args assert target == expected_target @pytest.mark.parametrize("empty_body", [None, b""]) def test_http_compression(self, empty_body): node = self._get_mock_node( NodeConfig("http", "localhost", 80, http_compress=True) ) assert node.config.http_compress is True assert node.headers["accept-encoding"] == "gzip" # 'content-encoding' shouldn't be set at a connection level. # Should be applied only if the request is sent with a body. assert "content-encoding" not in node.headers node.perform_request("GET", "/", body=b"{}") (_, _), kwargs = node.pool.urlopen.call_args body = kwargs["body"] assert gzip.decompress(body) == b"{}" assert kwargs["headers"]["accept-encoding"] == "gzip" assert kwargs["headers"]["content-encoding"] == "gzip" node.perform_request("GET", "/", body=empty_body) (_, _), kwargs = node.pool.urlopen.call_args assert kwargs["body"] is None assert kwargs["headers"]["accept-encoding"] == "gzip" assert "content-encoding" not in kwargs["headers"] def test_default_headers(self): node = self._get_mock_node(NodeConfig("http", "localhost", 80)) node.perform_request("GET", "/") (_, _), kwargs = node.pool.urlopen.call_args assert kwargs["headers"] == { "connection": "keep-alive", "user-agent": DEFAULT_USER_AGENT, } @pytest.mark.parametrize("request_timeout", [None, 15]) def test_timeout_override_default(self, request_timeout): node = Urllib3HttpNode( NodeConfig("http", "localhost", 80, request_timeout=request_timeout) ) assert node.config.request_timeout == request_timeout assert node.pool.timeout.total == request_timeout with patch.object(node.pool, "urlopen") as pool_urlopen: resp = Mock() resp.status = 200 resp.headers = {} pool_urlopen.return_value = resp node.perform_request("GET", "/", request_timeout=request_timeout) _, kwargs = pool_urlopen.call_args assert kwargs["timeout"] == request_timeout def test_uses_https_if_verify_certs_is_off(self): with warnings.catch_warnings(record=True) as w: con = Urllib3HttpNode( NodeConfig("https", "localhost", 443, verify_certs=False) ) assert 1 == len(w) assert ( "Connecting to 'https://localhost:443' using TLS with verify_certs=False is insecure" == str(w[0].message) ) assert isinstance(con.pool, urllib3.HTTPSConnectionPool) def test_no_warn_when_uses_https_if_verify_certs_is_off(self): with warnings.catch_warnings(record=True) as w: con = Urllib3HttpNode( NodeConfig( "https", "localhost", 443, verify_certs=False, ssl_show_warn=False ) ) assert 0 == len(w) assert isinstance(con.pool, urllib3.HTTPSConnectionPool) def test_no_warning_when_using_ssl_context(self): ctx = ssl.create_default_context() with warnings.catch_warnings(record=True) as w: Urllib3HttpNode(NodeConfig("https", "localhost", 443, ssl_context=ctx)) assert 0 == len(w) def test_surrogatepass_into_bytes(self): data = b"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa" con = self._get_mock_node( NodeConfig("http", "localhost", 80), response_body=data ) _, data = con.perform_request("GET", "/") assert b"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa" == data @pytest.mark.xfail @patch("elastic_transport._node._base.logger") def test_uncompressed_body_logged(self, logger): con = self._get_mock_node(connection_params={"http_compress": True}) con.perform_request("GET", "/", body=b'{"example": "body"}') assert 2 == logger.debug.call_count req, resp = logger.debug.call_args_list assert '> {"example": "body"}' == req[0][0] % req[0][1:] assert "< {}" == resp[0][0] % resp[0][1:] @pytest.mark.xfail @patch("elastic_transport._node._base.logger") def test_failed_request_logs(self, logger): conn = Urllib3HttpNode() with patch.object(conn.pool, "urlopen") as pool_urlopen: resp = Mock() resp.data = b'{"answer":42}' resp.status = 500 resp.headers = {} pool_urlopen.return_value = resp with pytest.raises(TransportError) as e: conn.perform_request( "GET", "/?param=42", b"{}", ) assert repr(e.value) == "InternalServerError({'answer': 42}, status=500)" # log url and duration assert 1 == logger.warning.call_count assert re.match( r"^GET http://localhost/\?param=42 \[status:500 request:0.[0-9]{3}s\]", logger.warning.call_args[0][0] % logger.warning.call_args[0][1:], ) assert 2 == logger.debug.call_count req, resp = logger.debug.call_args_list assert "> {}" == req[0][0] % req[0][1:] assert '< {"answer":42}' == resp[0][0] % resp[0][1:] elastic-transport-python-8.17.1/tests/node/test_tls_versions.py000066400000000000000000000104551476450415400250560ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import functools import socket import ssl import pytest from elastic_transport import ( AiohttpHttpNode, HttpxAsyncHttpNode, NodeConfig, RequestsHttpNode, TlsError, Urllib3HttpNode, ) from elastic_transport._compat import await_if_coro from elastic_transport.client_utils import url_to_node_config TLSv1_0_URL = "https://tls-v1-0.badssl.com:1010" TLSv1_1_URL = "https://tls-v1-1.badssl.com:1011" TLSv1_2_URL = "https://tls-v1-2.badssl.com:1012" node_classes = pytest.mark.parametrize( "node_class", [AiohttpHttpNode, Urllib3HttpNode, RequestsHttpNode, HttpxAsyncHttpNode], ) supported_version_params = [ (TLSv1_0_URL, ssl.PROTOCOL_TLSv1), (TLSv1_1_URL, ssl.PROTOCOL_TLSv1_1), (TLSv1_2_URL, ssl.PROTOCOL_TLSv1_2), (TLSv1_2_URL, None), ] unsupported_version_params = [ (TLSv1_0_URL, None), (TLSv1_1_URL, None), (TLSv1_0_URL, ssl.PROTOCOL_TLSv1_1), (TLSv1_0_URL, ssl.PROTOCOL_TLSv1_2), (TLSv1_1_URL, ssl.PROTOCOL_TLSv1_2), ] try: from ssl import TLSVersion except ImportError: pass else: supported_version_params.extend( [ (TLSv1_0_URL, TLSVersion.TLSv1), (TLSv1_1_URL, TLSVersion.TLSv1_1), (TLSv1_2_URL, TLSVersion.TLSv1_2), ] ) unsupported_version_params.extend( [ (TLSv1_0_URL, TLSVersion.TLSv1_1), (TLSv1_0_URL, TLSVersion.TLSv1_2), (TLSv1_1_URL, TLSVersion.TLSv1_2), (TLSv1_0_URL, TLSVersion.TLSv1_3), (TLSv1_1_URL, TLSVersion.TLSv1_3), (TLSv1_2_URL, TLSVersion.TLSv1_3), ] ) @functools.lru_cache() def tlsv1_1_supported() -> bool: # OpenSSL distributions on Ubuntu/Debian disable TLSv1.1 and before incorrectly. # So we try to detect that and skip tests when needed. try: sock = socket.create_connection(("tls-v1-1.badssl.com", 1011)) ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_1) sock = ctx.wrap_socket(sock, server_hostname="tls-v1-1.badssl.com") sock.close() except ssl.SSLError: return False return True @node_classes @pytest.mark.parametrize( ["url", "ssl_version"], supported_version_params, ) @pytest.mark.asyncio async def test_supported_tls_versions(node_class, url: str, ssl_version: int): if url in (TLSv1_0_URL, TLSv1_1_URL) and not tlsv1_1_supported(): pytest.skip("TLSv1.1 isn't supported by this OpenSSL distribution") node_config = url_to_node_config(url).replace(ssl_version=ssl_version) node = node_class(node_config) resp, _ = await await_if_coro(node.perform_request("GET", "/")) assert resp.status == 200 @node_classes @pytest.mark.parametrize( ["url", "ssl_version"], unsupported_version_params, ) @pytest.mark.asyncio async def test_unsupported_tls_versions(node_class, url: str, ssl_version: int): node_config = url_to_node_config(url).replace(ssl_version=ssl_version) node = node_class(node_config) with pytest.raises(TlsError) as e: await await_if_coro(node.perform_request("GET", "/")) assert "unsupported protocol" in str(e.value) or "handshake failure" in str(e.value) @node_classes @pytest.mark.parametrize("ssl_version", [0, "TLSv1", object()]) def test_ssl_version_value_error(node_class, ssl_version): with pytest.raises(ValueError) as e: node_class(NodeConfig("https", "localhost", 9200, ssl_version=ssl_version)) assert str(e.value) == ( f"Unsupported value for 'ssl_version': {ssl_version!r}. Must be either " "'ssl.PROTOCOL_TLSvX' or 'ssl.TLSVersion.TLSvX'" ) elastic-transport-python-8.17.1/tests/node/test_urllib3_chain_certs.py000066400000000000000000000064541476450415400262460ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import sys import warnings import pytest from elastic_transport import NodeConfig, RequestsHttpNode, TlsError, Urllib3HttpNode requires_ssl_assert_fingerprint_in_chain = pytest.mark.skipif( sys.version_info < (3, 10) or sys.implementation.name != "cpython", reason="Requires CPython 3.10+", ) @requires_ssl_assert_fingerprint_in_chain @pytest.mark.parametrize("node_cls", [Urllib3HttpNode, RequestsHttpNode]) def test_ssl_assert_fingerprint_invalid_length(node_cls): with pytest.raises(ValueError) as e: node_cls( NodeConfig( "https", "httpbin.org", 443, ssl_assert_fingerprint="0000", ) ) assert ( str(e.value) == "Fingerprint of invalid length '4', should be one of '32', '40', '64'" ) @requires_ssl_assert_fingerprint_in_chain @pytest.mark.parametrize("node_cls", [Urllib3HttpNode, RequestsHttpNode]) @pytest.mark.parametrize( "ssl_assert_fingerprint", [ "8ecde6884f3d87b1125ba31ac3fcb13d7016de7f57cc904fe1cb97c6ae98196e", "8e:cd:e6:88:4f:3d:87:b1:12:5b:a3:1a:c3:fc:b1:3d:70:16:de:7f:57:cc:90:4f:e1:cb:97:c6:ae:98:19:6e", "8ECDE6884F3D87B1125BA31AC3FCB13D7016DE7F57CC904FE1CB97C6AE98196E", ], ) def test_assert_fingerprint_in_cert_chain(node_cls, ssl_assert_fingerprint): with warnings.catch_warnings(record=True) as w: node = node_cls( NodeConfig( "https", "httpbin.org", 443, ssl_assert_fingerprint=ssl_assert_fingerprint, ) ) meta, _ = node.perform_request("GET", "/") assert meta.status == 200 assert w == [] @requires_ssl_assert_fingerprint_in_chain @pytest.mark.parametrize("node_cls", [Urllib3HttpNode, RequestsHttpNode]) def test_assert_fingerprint_in_cert_chain_failure(node_cls): node = node_cls( NodeConfig( "https", "httpbin.org", 443, ssl_assert_fingerprint="0" * 64, ) ) with pytest.raises(TlsError) as e: node.perform_request("GET", "/") err = str(e.value) assert "Fingerprints did not match." in err # This is the bad value we "expected" assert ( 'Expected "0000000000000000000000000000000000000000000000000000000000000000",' in err ) # This is the root CA for httpbin.org with a leading comma to denote more than one cert was listed. assert ', "8ecde6884f3d87b1125ba31ac3fcb13d7016de7f57cc904fe1cb97c6ae98196e"' in err elastic-transport-python-8.17.1/tests/test_client_utils.py000066400000000000000000000201441476450415400240710ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. from platform import python_version import pytest from elastic_transport import Urllib3HttpNode, __version__ from elastic_transport.client_utils import ( basic_auth_to_header, client_meta_version, create_user_agent, parse_cloud_id, url_to_node_config, ) def test_create_user_agent(): assert create_user_agent( "enterprise-search-python", "7.10.0" ) == "enterprise-search-python/7.10.0 (Python/{}; elastic-transport/{})".format( python_version(), __version__, ) @pytest.mark.parametrize( ["version", "meta_version"], [ ("7.10.0", "7.10.0"), ("7.10.0-alpha1", "7.10.0p"), ("3.9.0b1", "3.9.0p"), ("3.9.pre1", "3.9p"), ("3.7.4.post1", "3.7.4"), ("3.7.4.post", "3.7.4"), ], ) def test_client_meta_version(version, meta_version): assert client_meta_version(version) == meta_version def test_parse_cloud_id(): cloud_id = parse_cloud_id( "cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVk" "MWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==" ) assert cloud_id.cluster_name == "cluster" assert cloud_id.es_address == ( "4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io", 443, ) assert cloud_id.kibana_address == ( "4fa8821e75634032bed1cf22110e2f96.us-east-1.aws.found.io", 443, ) @pytest.mark.parametrize( ["cloud_id", "port"], [ ( ":dXMtZWFzdC0xLmF3cy5mb3VuZC5pbzo5MjQzJDRmYTg4MjFlNzU2MzQwMzJiZ" "WQxY2YyMjExMGUyZjk3JDRmYTg4MjFlNzU2MzQwMzJiZWQxY2YyMjExMGUyZjk2", 9243, ), ( ":dXMtZWFzdC0xLmF3cy5mb3VuZC5pbzo0NDMkNGZhODgyMWU3NTYzNDAzMmJlZD" "FjZjIyMTEwZTJmOTckNGZhODgyMWU3NTYzNDAzMmJlZDFjZjIyMTEwZTJmOTY=", 443, ), ], ) def test_parse_cloud_id_ports(cloud_id, port): cloud_id = parse_cloud_id(cloud_id) assert cloud_id.cluster_name == "" assert cloud_id.es_address == ( "4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io", port, ) assert cloud_id.kibana_address == ( "4fa8821e75634032bed1cf22110e2f96.us-east-1.aws.found.io", port, ) @pytest.mark.parametrize( "cloud_id", [ "cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ=", "cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Nw==", ], ) def test_parse_cloud_id_no_kibana(cloud_id): cloud_id = parse_cloud_id(cloud_id) assert cloud_id.cluster_name == "cluster" assert cloud_id.es_address == ( "4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io", 443, ) assert cloud_id.kibana_address is None @pytest.mark.parametrize( "cloud_id", [ "cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbzo0NDMkJA==", "cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbzo0NDM=", ], ) def test_parse_cloud_id_no_es(cloud_id): cloud_id = parse_cloud_id(cloud_id) assert cloud_id.cluster_name == "cluster" assert cloud_id.es_address is None assert cloud_id.kibana_address is None @pytest.mark.parametrize( "cloud_id", [ "cluster:", "dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ=", "cluster:ā", ], ) def test_invalid_cloud_id(cloud_id): with pytest.raises(ValueError) as e: parse_cloud_id(cloud_id) assert str(e.value) == "Cloud ID is not properly formatted" @pytest.mark.parametrize( ["url", "node_base_url", "path_prefix"], [ ("https://localhost", "https://localhost:443", ""), ("http://localhost:3002", "http://localhost:3002", ""), ("http://127.0.0.1:3002", "http://127.0.0.1:3002", ""), ("http://127.0.0.1:3002/", "http://127.0.0.1:3002", ""), ( "http://127.0.0.1:3002/path-prefix", "http://127.0.0.1:3002/path-prefix", "/path-prefix", ), ( "http://localhost:3002/url-prefix/", "http://localhost:3002/url-prefix", "/url-prefix", ), ( "https://localhost/url-prefix", "https://localhost:443/url-prefix", "/url-prefix", ), ("http://[::1]:3002/url-prefix", "http://[::1]:3002/url-prefix", "/url-prefix"), ("https://[::1]:0/", "https://[::1]:0", ""), ], ) def test_url_to_node_config(url, node_base_url, path_prefix): node_config = url_to_node_config(url) assert Urllib3HttpNode(node_config).base_url == node_base_url assert "[" not in node_config.host assert isinstance(node_config.port, int) assert node_config.path_prefix == path_prefix assert url.lower().startswith(node_config.scheme) @pytest.mark.parametrize( "url", [ "localhost:0", "[::1]:3002/url-prefix", "localhost", "localhost/", "localhost:3", "[::1]/url-prefix/", "[::1]", "[::1]:3002", "http://localhost", "localhost/url-prefix/", "localhost:3002/url-prefix", "http://localhost/url-prefix", ], ) def test_url_to_node_config_error_missing_component(url): with pytest.raises(ValueError) as e: url_to_node_config(url) assert ( str(e.value) == "URL must include a 'scheme', 'host', and 'port' component (ie 'https://localhost:9200')" ) @pytest.mark.parametrize( ["url", "port"], [ ("http://127.0.0.1", 80), ("http://[::1]", 80), ("HTTPS://localhost", 443), ("https://localhost/url-prefix", 443), ], ) def test_url_to_node_config_use_default_ports_for_scheme(url, port): node_config = url_to_node_config(url, use_default_ports_for_scheme=True) assert node_config.port == port def test_url_with_auth_into_authorization(): node_config = url_to_node_config("http://localhost:9200") assert node_config.headers == {} node_config = url_to_node_config("http://@localhost:9200") assert node_config.headers == {} node_config = url_to_node_config("http://user:pass@localhost:9200") assert node_config.headers == {"Authorization": "Basic dXNlcjpwYXNz"} node_config = url_to_node_config("http://user:@localhost:9200") assert node_config.headers == {"Authorization": "Basic dXNlcjo="} node_config = url_to_node_config("http://user@localhost:9200") assert node_config.headers == {"Authorization": "Basic dXNlcjo="} node_config = url_to_node_config("http://me@example.com:password@localhost:9200") assert node_config.headers == { "Authorization": "Basic bWVAZXhhbXBsZS5jb206cGFzc3dvcmQ=" } # ensure username and password are passed to basic auth unmodified basic_auth = basic_auth_to_header(("user:@", "@password")) node_config = url_to_node_config("http://user:@:@password@localhost:9200") assert node_config.headers == {"Authorization": basic_auth} node_config = url_to_node_config("http://user%3A%40:%40password@localhost:9200") assert node_config.headers == {"Authorization": basic_auth} @pytest.mark.parametrize( "basic_auth", ["", b"", ("",), ("", 1), (1, ""), ["", ""], False, object()] ) def test_basic_auth_errors(basic_auth): with pytest.raises(ValueError) as e: basic_auth_to_header(basic_auth) assert ( str(e.value) == "'basic_auth' must be a 2-tuple of str/bytes (username, password)" ) elastic-transport-python-8.17.1/tests/test_exceptions.py000066400000000000000000000041601476450415400235540ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import pytest from elastic_transport import ApiError, ApiResponseMeta, TransportError def test_exception_repr_and_str(): e = TransportError({"errors": [{"status": 500}]}) assert repr(e) == "TransportError({'errors': [{'status': 500}]})" assert str(e) == "{'errors': [{'status': 500}]}" e = TransportError("error", errors=(ValueError("value error"),)) assert repr(e) == "TransportError('error', errors={!r})".format( e.errors, ) assert str(e) == "error" def test_api_error_status_repr(): e = ApiError( {"errors": [{"status": 500}]}, body={}, meta=ApiResponseMeta( status=500, http_version="1.1", headers={}, duration=0.0, node=None ), ) assert ( repr(e) == "ApiError({'errors': [{'status': 500}]}, meta=ApiResponseMeta(status=500, http_version='1.1', headers={}, duration=0.0, node=None), body={})" ) assert str(e) == "[500] {'errors': [{'status': 500}]}" def test_api_error_is_not_transport_error(): with pytest.raises(ApiError): try: raise ApiError("", None, None) except TransportError: pass def test_transport_error_is_not_api_error(): with pytest.raises(TransportError): try: raise TransportError( "", ) except ApiError: pass elastic-transport-python-8.17.1/tests/test_httpbin.py000066400000000000000000000107731476450415400230520ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import dataclasses import json import pytest from elastic_transport import Transport from elastic_transport._node._base import DEFAULT_USER_AGENT from elastic_transport._transport import NODE_CLASS_NAMES @pytest.mark.parametrize("node_class", ["urllib3", "requests"]) def test_simple_request(node_class, httpbin_node_config): t = Transport([httpbin_node_config], node_class=node_class) resp, data = t.perform_request( "GET", "/anything?key[]=1&key[]=2&q1&q2=", headers={"Custom": "headeR", "content-type": "application/json"}, body={"JSON": "body"}, ) assert resp.status == 200 assert data["method"] == "GET" assert data["url"] == "https://httpbin.org/anything?key[]=1&key[]=2&q1&q2=" # httpbin makes no-value query params into '' assert data["args"] == { "key[]": ["1", "2"], "q1": "", "q2": "", } assert data["data"] == '{"JSON":"body"}' assert data["json"] == {"JSON": "body"} request_headers = { "Content-Type": "application/json", "Content-Length": "15", "Custom": "headeR", "Host": "httpbin.org", } assert all(v == data["headers"][k] for k, v in request_headers.items()) @pytest.mark.parametrize("node_class", ["urllib3", "requests"]) def test_node(node_class, httpbin_node_config): def new_node(**kwargs): return NODE_CLASS_NAMES[node_class]( dataclasses.replace(httpbin_node_config, **kwargs) ) node = new_node() resp, data = node.perform_request("GET", "/anything") assert resp.status == 200 parsed = parse_httpbin(data) assert parsed == { "headers": { "Accept-Encoding": "identity", "Host": "httpbin.org", "User-Agent": DEFAULT_USER_AGENT, }, "method": "GET", "url": "https://httpbin.org/anything", } node = new_node(http_compress=True) resp, data = node.perform_request("GET", "/anything") assert resp.status == 200 parsed = parse_httpbin(data) assert parsed == { "headers": { "Accept-Encoding": "gzip", "Host": "httpbin.org", "User-Agent": DEFAULT_USER_AGENT, }, "method": "GET", "url": "https://httpbin.org/anything", } resp, data = node.perform_request("GET", "/anything", body=b"hello, world!") assert resp.status == 200 parsed = parse_httpbin(data) assert parsed == { "headers": { "Accept-Encoding": "gzip", "Content-Encoding": "gzip", "Content-Length": "33", "Host": "httpbin.org", "User-Agent": DEFAULT_USER_AGENT, }, "method": "GET", "url": "https://httpbin.org/anything", } resp, data = node.perform_request( "POST", "/anything", body=json.dumps({"key": "value"}).encode("utf-8"), headers={"content-type": "application/json"}, ) assert resp.status == 200 parsed = parse_httpbin(data) assert parsed == { "headers": { "Accept-Encoding": "gzip", "Content-Encoding": "gzip", "Content-Length": "36", "Content-Type": "application/json", "Host": "httpbin.org", "User-Agent": DEFAULT_USER_AGENT, }, "method": "POST", "url": "https://httpbin.org/anything", } def parse_httpbin(value): """Parses a response from httpbin.org/anything by stripping all the variable things""" if isinstance(value, bytes): value = json.loads(value) else: value = value.copy() value.pop("origin", None) value.pop("data", None) value["headers"].pop("X-Amzn-Trace-Id", None) value = {k: v for k, v in value.items() if v} return value elastic-transport-python-8.17.1/tests/test_httpserver.py000066400000000000000000000023721476450415400236040ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import warnings import pytest from elastic_transport import Transport @pytest.mark.parametrize("node_class", ["urllib3", "requests"]) def test_simple_request(node_class, https_server_ip_node_config): with warnings.catch_warnings(): warnings.simplefilter("error") t = Transport([https_server_ip_node_config], node_class=node_class) resp, data = t.perform_request("GET", "/foobar") assert resp.status == 200 assert data == {"foo": "bar"} elastic-transport-python-8.17.1/tests/test_logging.py000066400000000000000000000110201476450415400230120ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import io import logging import pytest from elastic_transport import ( AiohttpHttpNode, ConnectionError, HttpHeaders, RequestsHttpNode, Urllib3HttpNode, debug_logging, ) from elastic_transport._compat import await_if_coro from elastic_transport._node._base import DEFAULT_USER_AGENT node_class = pytest.mark.parametrize( "node_class", [Urllib3HttpNode, RequestsHttpNode, AiohttpHttpNode] ) @node_class @pytest.mark.asyncio async def test_debug_logging(node_class, httpbin_node_config): debug_logging() stream = io.StringIO() logging.getLogger("elastic_transport.node").addHandler( logging.StreamHandler(stream) ) node = node_class(httpbin_node_config) await await_if_coro( node.perform_request( "GET", "/anything", body=b'{"key":"value"}', headers=HttpHeaders({"Content-Type": "application/json"}), ) ) print(node_class) print(stream.getvalue()) lines = stream.getvalue().split("\n") print(lines) for line in [ "> GET /anything HTTP/1.1", "> Connection: keep-alive", "> Content-Type: application/json", f"> User-Agent: {DEFAULT_USER_AGENT}", '> {"key":"value"}', "< HTTP/1.1 200 OK", "< Access-Control-Allow-Credentials: true", "< Access-Control-Allow-Origin: *", "< Connection: close", "< Content-Type: application/json", "< {", ' "args": {}, ', ' "data": "{\\"key\\":\\"value\\"}", ', ' "files": {}, ', ' "form": {}, ', ' "headers": {', ' "Content-Type": "application/json", ', ' "Host": "httpbin.org", ', f' "User-Agent": "{DEFAULT_USER_AGENT}", ', " }, ", ' "json": {', ' "key": "value"', " }, ", ' "method": "GET", ', ' "url": "https://httpbin.org/anything"', "}", ]: assert line in lines @node_class @pytest.mark.asyncio async def test_debug_logging_uncompressed_body(httpbin_node_config, node_class): debug_logging() stream = io.StringIO() logging.getLogger("elastic_transport.node").addHandler( logging.StreamHandler(stream) ) node = node_class(httpbin_node_config.replace(http_compress=True)) await await_if_coro( node.perform_request( "GET", "/anything", body=b'{"key":"value"}', headers=HttpHeaders({"Content-Type": "application/json"}), ) ) lines = stream.getvalue().split("\n") print(lines) assert '> {"key":"value"}' in lines @node_class @pytest.mark.asyncio async def test_debug_logging_no_body(httpbin_node_config, node_class): debug_logging() stream = io.StringIO() logging.getLogger("elastic_transport.node").addHandler( logging.StreamHandler(stream) ) node = node_class(httpbin_node_config) await await_if_coro( node.perform_request( "HEAD", "/anything", ) ) lines = stream.getvalue().split("\n")[:-3] assert "> HEAD /anything HTTP/1.1" in lines @node_class @pytest.mark.asyncio async def test_debug_logging_error(httpbin_node_config, node_class): debug_logging() stream = io.StringIO() logging.getLogger("elastic_transport.node").addHandler( logging.StreamHandler(stream) ) node = node_class(httpbin_node_config.replace(host="not.a.valid.host")) with pytest.raises(ConnectionError): await await_if_coro( node.perform_request( "HEAD", "/anything", ) ) lines = stream.getvalue().split("\n")[:-3] assert "> HEAD /anything HTTP/?.?" in lines assert all(not line.startswith("<") for line in lines) elastic-transport-python-8.17.1/tests/test_models.py000066400000000000000000000056341476450415400226650ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import dataclasses import pytest from elastic_transport import HttpHeaders, NodeConfig def test_empty_node_config(): config = NodeConfig(scheme="https", host="localhost", port=9200) assert dataclasses.asdict(config) == { "ca_certs": None, "client_cert": None, "client_key": None, "connections_per_node": 10, "headers": {}, "host": "localhost", "http_compress": False, "path_prefix": "", "port": 9200, "request_timeout": 10, "scheme": "https", "ssl_assert_fingerprint": None, "ssl_assert_hostname": None, "ssl_context": None, "ssl_show_warn": True, "ssl_version": None, "verify_certs": True, "_extras": {}, } # Default HttpHeaders should be empty and frozen assert len(config.headers) == 0 assert config.headers.frozen def test_headers_frozen(): headers = HttpHeaders() assert headers.frozen is False headers["key"] = "value" headers.pop("Key") headers["key"] = "value" assert headers.freeze() is headers assert headers.frozen is True with pytest.raises(ValueError) as e: headers["key"] = "value" assert str(e.value) == "Can't modify headers that have been frozen" with pytest.raises(ValueError) as e: headers.pop("key") assert str(e.value) == "Can't modify headers that have been frozen" assert len(headers) == 1 assert headers == {"key": "value"} assert headers.copy() is not headers assert headers.copy().frozen is False @pytest.mark.parametrize( ["headers", "string"], [ ({"field": "value"}, "{'field': 'value'}"), ({"Authorization": "value"}, "{'Authorization': ''}"), ({"authorization": "Basic"}, "{'authorization': ''}"), ({"authorization": "Basic abc"}, "{'authorization': 'Basic '}"), ({"authorization": "ApiKey abc"}, "{'authorization': 'ApiKey '}"), ({"authorization": "Bearer abc"}, "{'authorization': 'Bearer '}"), ], ) def test_headers_hide_auth(headers, string): assert repr(HttpHeaders(headers)) == string elastic-transport-python-8.17.1/tests/test_node_pool.py000066400000000000000000000200111476450415400233420ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import random import threading import time import pytest from elastic_transport import NodeConfig, NodePool, Urllib3HttpNode def test_node_pool_repr(): node_configs = [NodeConfig("http", "localhost", x) for x in range(5)] random.shuffle(node_configs) pool = NodePool(node_configs, node_class=Urllib3HttpNode) assert repr(pool) == "" def test_node_pool_empty_error(): with pytest.raises(ValueError) as e: NodePool([], node_class=Urllib3HttpNode) assert str(e.value) == "Must specify at least one NodeConfig" def test_node_pool_duplicate_node_configs(): node_config = NodeConfig("http", "localhost", 80) with pytest.raises(ValueError) as e: NodePool([node_config, node_config], node_class=Urllib3HttpNode) assert str(e.value) == "Cannot use duplicate NodeConfigs within a NodePool" def test_node_pool_get(): node_config = NodeConfig("http", "localhost", 80) pool = NodePool([node_config], node_class=Urllib3HttpNode) assert pool.get().config is node_config def test_node_pool_remove_seed_node(): node_config = NodeConfig("http", "localhost", 80) pool = NodePool([node_config], node_class=Urllib3HttpNode) pool.remove(node_config) # Calling .remove() on a seed node is a no-op assert len(pool._removed_nodes) == 0 def test_node_pool_add_and_remove_non_seed_node(): node_config1 = NodeConfig("http", "localhost", 80) node_config2 = NodeConfig("http", "localhost", 81) pool = NodePool([node_config1], node_class=Urllib3HttpNode) pool.add(node_config2) assert any(pool.get().config is node_config2 for _ in range(10)) pool.remove(node_config2) assert len(pool._removed_nodes) == 1 # We never return a 'removed' node assert all(pool.get().config is node_config1 for _ in range(10)) # We add it back, now we should .get() the node again. pool.add(node_config2) assert any(pool.get().config is node_config2 for _ in range(10)) def test_added_node_is_used_first(): node_config1 = NodeConfig("http", "localhost", 80) node_config2 = NodeConfig("http", "localhost", 81) pool = NodePool([node_config1], node_class=Urllib3HttpNode) node1 = pool.get() pool.mark_dead(node1) pool.add(node_config2) assert pool.get().config is node_config2 def test_round_robin_selector(): node_configs = [NodeConfig("http", "localhost", x) for x in range(5)] random.shuffle(node_configs) pool = NodePool( node_configs, node_class=Urllib3HttpNode, node_selector_class="round_robin" ) get_node_configs = [pool.get() for _ in node_configs] for node_config in get_node_configs: assert pool.get() is node_config @pytest.mark.parametrize( "node_configs", [ [NodeConfig("http", "localhost", 80)], [NodeConfig("http", "localhost", 80), NodeConfig("http", "localhost", 81)], ], ) def test_all_dead_nodes_still_gets_node(node_configs): pool = NodePool(node_configs, node_class=Urllib3HttpNode) for _ in node_configs: pool.mark_dead(pool.get()) assert len(pool._alive_nodes) == 0 node = pool.get() assert node.config in node_configs assert len(pool._alive_nodes) < 2 def test_unknown_selector_class(): with pytest.raises(ValueError) as e: NodePool( [NodeConfig("http", "localhost", 80)], node_class=Urllib3HttpNode, node_selector_class="unknown", ) assert str(e.value) == ( "Unknown option for selector_class: 'unknown'. " "Available options are: 'random', 'round_robin'" ) def test_disable_randomize_nodes(): node_configs = [NodeConfig("http", "localhost", x) for x in range(100)] pool = NodePool(node_configs, node_class=Urllib3HttpNode, randomize_nodes=False) assert [pool.get().config for _ in node_configs] == node_configs def test_nodes_randomized_by_default(): node_configs = [NodeConfig("http", "localhost", x) for x in range(100)] pool = NodePool(node_configs, node_class=Urllib3HttpNode) assert [pool.get().config for _ in node_configs] != node_configs def test_dead_nodes_are_skipped(): node_configs = [NodeConfig("http", "localhost", x) for x in range(2)] pool = NodePool(node_configs, node_class=Urllib3HttpNode) dead_node = pool.get() pool.mark_dead(dead_node) alive_node = pool.get() assert dead_node.config != alive_node.config assert all([pool.get().config == alive_node.config for _ in range(10)]) def test_dead_node_backoff_calculation(): node_configs = [NodeConfig("http", "localhost", 9200)] pool = NodePool( node_configs, node_class=Urllib3HttpNode, dead_node_backoff_factor=0.5, max_dead_node_backoff=3.5, ) node = pool.get() pool.mark_dead(node, _now=0) assert pool._dead_consecutive_failures == {node.config: 1} assert pool._dead_nodes.queue == [(0.5, node)] assert pool.get() is node pool.mark_dead(node, _now=0) assert pool._dead_consecutive_failures == {node.config: 2} assert pool._dead_nodes.queue == [(1.0, node)] assert pool.get() is node pool.mark_dead(node, _now=0) assert pool._dead_consecutive_failures == {node.config: 3} assert pool._dead_nodes.queue == [(2.0, node)] assert pool.get() is node pool.mark_dead(node, _now=0) assert pool._dead_consecutive_failures == {node.config: 4} assert pool._dead_nodes.queue == [(3.5, node)] assert pool.get() is node pool.mark_dead(node, _now=0) pool._dead_consecutive_failures = {node.config: 13292} assert pool._dead_nodes.queue == [(3.5, node)] assert pool.get() is node pool.mark_live(node) assert pool._dead_consecutive_failures == {} assert pool._dead_nodes.queue == [] def test_add_node_after_sniffing(): node_configs = [NodeConfig("http", "localhost", 9200)] pool = NodePool( node_configs, node_class=Urllib3HttpNode, ) # Initial node is marked as dead node = pool.get() pool.mark_dead(node) new_node_config = NodeConfig("http", "localhost", 9201) pool.add(new_node_config) # Internal flag is updated properly assert pool._all_nodes_len_1 is False # We get the new node instead of the old one new_node = pool.get() assert new_node.config == new_node_config # The old node is still on timeout so we should only get the new one. for _ in range(10): assert pool.get() is new_node @pytest.mark.parametrize("pool_size", [1, 8]) def test_threading_test(pool_size): pool = NodePool( [NodeConfig("http", "localhost", x) for x in range(pool_size)], node_class=Urllib3HttpNode, ) start = time.time() class ThreadTest(threading.Thread): def __init__(self): super().__init__() self.nodes_gotten = 0 def run(self) -> None: nonlocal pool while time.time() < start + 2: node = pool.get() self.nodes_gotten += 1 if random.random() > 0.9: pool.mark_dead(node) else: pool.mark_live(node) threads = [ThreadTest() for _ in range(pool_size * 2)] [thread.start() for thread in threads] [thread.join() for thread in threads] assert sum(thread.nodes_gotten for thread in threads) >= 10000 elastic-transport-python-8.17.1/tests/test_otel.py000066400000000000000000000070251476450415400223410ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. from opentelemetry.sdk.trace import TracerProvider, export from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from elastic_transport import JsonSerializer from elastic_transport._otel import OpenTelemetrySpan def setup_tracing(): tracer_provider = TracerProvider() memory_exporter = InMemorySpanExporter() span_processor = export.SimpleSpanProcessor(memory_exporter) tracer_provider.add_span_processor(span_processor) tracer = tracer_provider.get_tracer(__name__) return tracer, memory_exporter def test_no_span(): # With telemetry disabled, those calls should not raise span = OpenTelemetrySpan(None) span.set_db_statement(JsonSerializer().dumps({"timeout": "1m"})) span.set_node_metadata( "localhost", 9200, "http://localhost:9200/", "_ml/anomaly_detectors/my-job/_open", ) span.set_elastic_cloud_metadata( { "X-Found-Handling-Cluster": "e9106fc68e3044f0b1475b04bf4ffd5f", "X-Found-Handling-Instance": "instance-0000000001", } ) def test_detailed_span(): tracer, memory_exporter = setup_tracing() with tracer.start_as_current_span("ml.open_job") as otel_span: span = OpenTelemetrySpan( otel_span, endpoint_id="my-job/_open", body_strategy="omit", ) span.set_db_statement(JsonSerializer().dumps({"timeout": "1m"})) span.set_node_metadata( "localhost", 9200, "http://localhost:9200/", "_ml/anomaly_detectors/my-job/_open", ) span.set_elastic_cloud_metadata( { "X-Found-Handling-Cluster": "e9106fc68e3044f0b1475b04bf4ffd5f", "X-Found-Handling-Instance": "instance-0000000001", } ) spans = memory_exporter.get_finished_spans() assert len(spans) == 1 assert spans[0].name == "ml.open_job" assert spans[0].attributes == { "url.full": "http://localhost:9200/_ml/anomaly_detectors/my-job/_open", "server.address": "localhost", "server.port": 9200, "db.elasticsearch.cluster.name": "e9106fc68e3044f0b1475b04bf4ffd5f", "db.elasticsearch.node.name": "instance-0000000001", } def test_db_statement(): tracer, memory_exporter = setup_tracing() with tracer.start_as_current_span("search") as otel_span: span = OpenTelemetrySpan(otel_span, endpoint_id="search", body_strategy="raw") span.set_db_statement(JsonSerializer().dumps({"query": {"match_all": {}}})) spans = memory_exporter.get_finished_spans() assert len(spans) == 1 assert spans[0].name == "search" assert spans[0].attributes == { "db.statement": '{"query":{"match_all":{}}}', } elastic-transport-python-8.17.1/tests/test_package.py000066400000000000000000000027171476450415400227740ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import pytest import elastic_transport from elastic_transport import client_utils modules = pytest.mark.parametrize("module", [elastic_transport, client_utils]) @modules def test__all__sorted(module): module_all = module.__all__.copy() # Optional dependencies are added at the end if "OrjsonSerializer" in module_all: module_all.remove("OrjsonSerializer") assert module_all == sorted(module_all) @modules def test__all__is_importable(module): assert {attr for attr in module.__all__ if hasattr(module, attr)} == set( module.__all__ ) def test_module_rewritten(): assert repr(elastic_transport.Transport) == "" elastic-transport-python-8.17.1/tests/test_response.py000066400000000000000000000110471476450415400232330ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import pickle import pytest from elastic_transport import ( ApiResponseMeta, BinaryApiResponse, HeadApiResponse, HttpHeaders, ListApiResponse, ObjectApiResponse, TextApiResponse, ) meta = ApiResponseMeta( status=200, http_version="1.1", headers=HttpHeaders(), duration=0, node=None ) @pytest.mark.parametrize( "response_cls", [TextApiResponse, BinaryApiResponse, ObjectApiResponse, ListApiResponse], ) def test_response_meta(response_cls): resp = response_cls(meta=meta, body=None) assert resp.meta is meta assert resp == resp assert resp.body == resp assert resp == resp.body assert not resp != resp assert not (resp != resp.body) def test_head_response(): resp = HeadApiResponse(meta=meta) assert resp assert resp.body is True assert bool(resp) is True assert resp.meta is meta assert repr(resp) == "HeadApiResponse(True)" def test_text_response(): resp = TextApiResponse(body="Hello, world", meta=meta) assert resp.body == "Hello, world" assert len(resp) == 12 assert resp.lower() == "hello, world" assert list(resp) == ["H", "e", "l", "l", "o", ",", " ", "w", "o", "r", "l", "d"] assert repr(resp) == "TextApiResponse('Hello, world')" def test_binary_response(): resp = BinaryApiResponse(body=b"Hello, world", meta=meta) assert resp.body == b"Hello, world" assert len(resp) == 12 assert resp[0] == 72 assert resp[:2] == b"He" assert resp.lower() == b"hello, world" assert resp.decode() == "Hello, world" assert list(resp) == [72, 101, 108, 108, 111, 44, 32, 119, 111, 114, 108, 100] assert repr(resp) == "BinaryApiResponse(b'Hello, world')" def test_list_response(): resp = ListApiResponse(body=[1, 2, 3], meta=meta) assert list(resp) == [1, 2, 3] assert resp.body == [1, 2, 3] assert resp[1] == 2 assert repr(resp) == "ListApiResponse([1, 2, 3])" def test_list_object_response(): resp = ObjectApiResponse(body={"k1": 1, "k2": 2}, meta=meta) assert set(resp.keys()) == {"k1", "k2"} assert resp["k2"] == 2 assert resp.body == {"k1": 1, "k2": 2} assert repr(resp) == "ObjectApiResponse({'k1': 1, 'k2': 2})" @pytest.mark.parametrize( "resp_cls", [ObjectApiResponse, ListApiResponse, TextApiResponse, BinaryApiResponse] ) @pytest.mark.parametrize( ["args", "kwargs"], [ ((), {}), ((1,), {}), ((1,), {"raw": 1}), ((1,), {"body": 1}), ((1,), {"meta": 1}), ((), {"raw": 1, "body": 1}), ((), {"raw": 1, "body": 1, "meta": 1}), ((1,), {"raw": 1, "meta": 1}), ((1,), {"meta": 1, "body": 1}), ((1, 1), {"meta": 1, "body": 1}), ((), {"meta": 1, "body": 1, "unk": 1}), ], ) def test_constructor_type_errors(resp_cls, args, kwargs): with pytest.raises(TypeError) as e: resp_cls(*args, **kwargs) assert str(e.value) == "Must pass 'meta' and 'body' to ApiResponse" def test_constructor_allowed(): resp = HeadApiResponse(meta) resp = HeadApiResponse(meta=meta) resp = ObjectApiResponse({}, meta) assert resp == {} resp = ObjectApiResponse(meta=meta, raw={}) assert resp == {} resp = ObjectApiResponse(meta=meta, raw={}, body_cls=int) assert resp == {} resp = ObjectApiResponse(meta=meta, body={}, body_cls=int) assert resp == {} @pytest.mark.parametrize( "response_cls, body", [ (TextApiResponse, "Hello World"), (BinaryApiResponse, b"Hello World"), (ObjectApiResponse, {"Hello": "World"}), (ListApiResponse, ["Hello", "World"]), ], ) def test_pickle(response_cls, body): resp = response_cls(meta=meta, body=body) pickled_resp = pickle.loads(pickle.dumps(resp)) assert pickled_resp == resp assert pickled_resp.meta == resp.meta elastic-transport-python-8.17.1/tests/test_serializer.py000066400000000000000000000137661476450415400235600ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import uuid from datetime import date from decimal import Decimal import pytest from elastic_transport import ( JsonSerializer, NdjsonSerializer, OrjsonSerializer, SerializationError, SerializerCollection, TextSerializer, ) from elastic_transport._serializer import DEFAULT_SERIALIZERS serializers = SerializerCollection(DEFAULT_SERIALIZERS) @pytest.fixture(params=[JsonSerializer, OrjsonSerializer]) def json_serializer(request: pytest.FixtureRequest): yield request.param() def test_date_serialization(json_serializer): assert b'{"d":"2010-10-01"}' == json_serializer.dumps({"d": date(2010, 10, 1)}) def test_decimal_serialization(json_serializer): assert b'{"d":3.8}' == json_serializer.dumps({"d": Decimal("3.8")}) def test_uuid_serialization(json_serializer): assert b'{"d":"00000000-0000-0000-0000-000000000003"}' == json_serializer.dumps( {"d": uuid.UUID("00000000-0000-0000-0000-000000000003")} ) def test_serializes_nan(): assert b'{"d":NaN}' == JsonSerializer().dumps({"d": float("NaN")}) # NaN is invalid JSON, and orjson silently converts it to null assert b'{"d":null}' == OrjsonSerializer().dumps({"d": float("NaN")}) def test_raises_serialization_error_on_dump_error(json_serializer): with pytest.raises(SerializationError): json_serializer.dumps(object()) with pytest.raises(SerializationError): TextSerializer().dumps({}) def test_raises_serialization_error_on_load_error(json_serializer): with pytest.raises(SerializationError): json_serializer.loads(object()) with pytest.raises(SerializationError): json_serializer.loads(b"{{") def test_json_unicode_is_handled(json_serializer): assert ( json_serializer.dumps({"你好": "你好"}) == b'{"\xe4\xbd\xa0\xe5\xa5\xbd":"\xe4\xbd\xa0\xe5\xa5\xbd"}' ) assert json_serializer.loads( b'{"\xe4\xbd\xa0\xe5\xa5\xbd":"\xe4\xbd\xa0\xe5\xa5\xbd"}' ) == {"你好": "你好"} def test_text_unicode_is_handled(): text_serializer = TextSerializer() assert text_serializer.dumps("你好") == b"\xe4\xbd\xa0\xe5\xa5\xbd" assert text_serializer.loads(b"\xe4\xbd\xa0\xe5\xa5\xbd") == "你好" def test_json_unicode_surrogates_handled(): assert ( JsonSerializer().dumps({"key": "你好\uda6a"}) == b'{"key":"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa"}' ) assert JsonSerializer().loads( b'{"key":"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa"}' ) == {"key": "你好\uda6a"} # orjson is strict about UTF-8 with pytest.raises(SerializationError): OrjsonSerializer().dumps({"key": "你好\uda6a"}) with pytest.raises(SerializationError): OrjsonSerializer().loads(b'{"key":"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa"}') def test_text_unicode_surrogates_handled(json_serializer): text_serializer = TextSerializer() assert ( text_serializer.dumps("你好\uda6a") == b"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa" ) assert ( text_serializer.loads(b"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa") == "你好\uda6a" ) def test_deserializes_json_by_default(): assert {"some": "data"} == serializers.loads(b'{"some":"data"}') def test_deserializes_text_with_correct_ct(): assert '{"some":"data"}' == serializers.loads(b'{"some":"data"}', "text/plain") assert '{"some":"data"}' == serializers.loads( b'{"some":"data"}', "text/plain; charset=whatever" ) def test_raises_serialization_error_on_unknown_mimetype(): with pytest.raises(SerializationError) as e: serializers.loads(b"{}", "fake/type") assert ( str(e.value) == "Unknown mimetype, not able to serialize or deserialize: fake/type" ) def test_raises_improperly_configured_when_default_mimetype_cannot_be_deserialized(): with pytest.raises(ValueError) as e: SerializerCollection({}) assert ( str(e.value) == "Must configure a serializer for the default mimetype 'application/json'" ) def test_text_asterisk_works_for_all_text_types(): assert serializers.loads(b"{}", "text/html") == "{}" assert serializers.dumps("{}", "text/html") == b"{}" @pytest.mark.parametrize("should_strip", [False, b"\n", b"\r\n"]) def test_ndjson_loads(should_strip): serializer = NdjsonSerializer() data = ( b'{"key":"value"}\n' b'{"number":0.1,"one":1}\n' b'{"list":[1,2,3]}\r\n' b'{"unicode":"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa"}\r\n' ) if should_strip: data = data.strip(should_strip) data = serializer.loads(data) assert data == [ {"key": "value"}, {"number": 0.1, "one": 1}, {"list": [1, 2, 3]}, {"unicode": "你好\uda6a"}, ] def test_ndjson_dumps(): serializer = NdjsonSerializer() data = serializer.dumps( [ {"key": "value"}, {"number": 0.1, "one": 1}, {"list": [1, 2, 3]}, {"unicode": "你好\uda6a"}, '{"key:"value"}', b'{"bytes":"too"}', ] ) assert data == ( b'{"key":"value"}\n' b'{"number":0.1,"one":1}\n' b'{"list":[1,2,3]}\n' b'{"unicode":"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa"}\n' b'{"key:"value"}\n' b'{"bytes":"too"}\n' ) elastic-transport-python-8.17.1/tests/test_transport.py000066400000000000000000000466641476450415400234460ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import random import re import ssl import threading import time import warnings from unittest import mock import pytest from elastic_transport import ( AiohttpHttpNode, ConnectionError, ConnectionTimeout, NodeConfig, RequestsHttpNode, SniffingError, SniffOptions, Transport, TransportError, TransportWarning, Urllib3HttpNode, ) from elastic_transport.client_utils import DEFAULT from tests.conftest import DummyNode def test_transport_close_node_pool(): t = Transport([NodeConfig("http", "localhost", 443)]) with mock.patch.object(t.node_pool.all()[0], "close") as node_close: t.close() node_close.assert_called_with() def test_request_with_custom_user_agent_header(): t = Transport( [NodeConfig("http", "localhost", 80)], node_class=DummyNode, meta_header=False ) t.perform_request("GET", "/", headers={"user-agent": "my-custom-value/1.2.3"}) assert 1 == len(t.node_pool.get().calls) assert { "body": None, "request_timeout": DEFAULT, "headers": {"user-agent": "my-custom-value/1.2.3"}, } == t.node_pool.get().calls[0][1] def test_body_gets_encoded_into_bytes(): t = Transport([NodeConfig("http", "localhost", 80)], node_class=DummyNode) t.perform_request( "GET", "/", headers={"Content-type": "application/json"}, body={"key": "你好"} ) calls = t.node_pool.get().calls assert 1 == len(calls) args, kwargs = calls[0] assert ("GET", "/") == args assert kwargs["body"] == b'{"key":"\xe4\xbd\xa0\xe5\xa5\xbd"}' def test_body_bytes_get_passed_untouched(): t = Transport([NodeConfig("http", "localhost", 80)], node_class=DummyNode) body = b"\xe4\xbd\xa0\xe5\xa5\xbd" t.perform_request( "GET", "/", body=body, headers={"Content-Type": "application/json"} ) calls = t.node_pool.get().calls assert 1 == len(calls) args, kwargs = calls[0] assert ("GET", "/") == args assert kwargs["body"] == b"\xe4\xbd\xa0\xe5\xa5\xbd" def test_empty_response_with_content_type(): t = Transport( [ NodeConfig( "http", "localhost", 80, _extras={"body": b"", "headers": {"Content-Type": "application/json"}}, ) ], node_class=DummyNode, ) resp = t.perform_request("GET", "/", headers={"Accept": "application/json"}) # Empty body is deserialized as 'None' instead of an error. assert resp.meta.status == 200 assert resp.body is None def test_kwargs_passed_on_to_node_pool(): dt = object() t = Transport( [NodeConfig("http", "localhost", 80)], dead_node_backoff_factor=dt, max_dead_node_backoff=dt, ) assert dt is t.node_pool.dead_node_backoff_factor assert dt is t.node_pool.max_dead_node_backoff def test_request_will_fail_after_x_retries(): t = Transport( [ NodeConfig( "http", "localhost", 80, _extras={"exception": ConnectionError("abandon ship")}, ) ], node_class=DummyNode, max_retries=0, ) with pytest.raises(ConnectionError) as e: t.perform_request("GET", "/") assert 1 == len(t.node_pool.get().calls) assert len(e.value.errors) == 0 # max_retries=3 t = Transport( [ NodeConfig( "http", "localhost", 80, _extras={"exception": ConnectionError("abandon ship")}, ) ], node_class=DummyNode, max_retries=3, ) with pytest.raises(ConnectionError) as e: t.perform_request("GET", "/") assert 4 == len(t.node_pool.get().calls) assert len(e.value.errors) == 3 assert all(isinstance(error, ConnectionError) for error in e.value.errors) # max_retries=2 in perform_request() with pytest.raises(ConnectionError) as e: t.perform_request("GET", "/", max_retries=2) assert 7 == len(t.node_pool.get().calls) assert len(e.value.errors) == 2 assert all(isinstance(error, ConnectionError) for error in e.value.errors) @pytest.mark.parametrize("retry_on_timeout", [True, False]) def test_retry_on_timeout(retry_on_timeout): t = Transport( [ NodeConfig( "http", "localhost", 80, _extras={"exception": ConnectionTimeout("abandon ship")}, ), NodeConfig( "http", "localhost", 81, _extras={"exception": ConnectionError("error!")}, ), ], node_class=DummyNode, max_retries=1, retry_on_timeout=retry_on_timeout, randomize_nodes_in_pool=False, ) if retry_on_timeout: with pytest.raises(ConnectionError) as e: t.perform_request("GET", "/") assert len(e.value.errors) == 1 assert isinstance(e.value.errors[0], ConnectionTimeout) else: with pytest.raises(ConnectionTimeout) as e: t.perform_request("GET", "/") assert len(e.value.errors) == 0 def test_retry_on_status(): t = Transport( [ NodeConfig("http", "localhost", 80, _extras={"status": 404}), NodeConfig( "http", "localhost", 81, _extras={"status": 401}, ), NodeConfig( "http", "localhost", 82, _extras={"status": 403}, ), NodeConfig( "http", "localhost", 83, _extras={"status": 555}, ), ], node_class=DummyNode, node_selector_class="round_robin", retry_on_status=(401, 403, 404), randomize_nodes_in_pool=False, max_retries=5, ) meta, _ = t.perform_request("GET", "/") assert meta.status == 555 # Assert that every node is called once node_calls = [len(node.calls) for node in t.node_pool.all()] assert node_calls == [ 1, 1, 1, 1, ] def test_failed_connection_will_be_marked_as_dead(): t = Transport( [ NodeConfig( "http", "localhost", 80, _extras={"exception": ConnectionError("abandon ship")}, ), NodeConfig( "http", "localhost", 81, _extras={"exception": ConnectionError("abandon ship")}, ), ], max_retries=3, node_class=DummyNode, ) with pytest.raises(ConnectionError) as e: t.perform_request("GET", "/") assert 0 == len(t.node_pool._alive_nodes) assert 2 == len(t.node_pool._dead_nodes.queue) assert len(e.value.errors) == 3 assert all(isinstance(error, ConnectionError) for error in e.value.errors) def test_resurrected_connection_will_be_marked_as_live_on_success(): for method in ("GET", "HEAD"): t = Transport( [ NodeConfig("http", "localhost", 80), NodeConfig("http", "localhost", 81), ], node_class=DummyNode, ) node1 = t.node_pool.get() node2 = t.node_pool.get() t.node_pool.mark_dead(node1) t.node_pool.mark_dead(node2) t.perform_request(method, "/") assert 1 == len(t.node_pool._alive_nodes) assert 1 == len(t.node_pool._dead_consecutive_failures) assert 1 == len(t.node_pool._dead_nodes.queue) def test_sniff_on_node_failure_error_doesnt_raise(): t = Transport( [ NodeConfig("http", "localhost", 80, _extras={"status": 502}), NodeConfig("http", "localhost", 81), ], max_retries=1, retry_on_status=(502,), node_class=DummyNode, randomize_nodes_in_pool=False, ) bad_node = t.node_pool._all_nodes[NodeConfig("http", "localhost", 80)] with mock.patch.object(t, "sniff") as sniff, mock.patch.object( t.node_pool, "mark_dead" ) as mark_dead: sniff.side_effect = TransportError("sniffing error!") t.perform_request("GET", "/") mark_dead.assert_called_with(bad_node) def test_node_class_as_string(): t = Transport([NodeConfig("http", "localhost", 80)], node_class="urllib3") assert isinstance(t.node_pool.get(), Urllib3HttpNode) t = Transport([NodeConfig("http", "localhost", 80)], node_class="requests") assert isinstance(t.node_pool.get(), RequestsHttpNode) with pytest.raises(ValueError) as e: Transport([NodeConfig("http", "localhost", 80)], node_class="huh?") assert str(e.value) == ( "Unknown option for node_class: 'huh?'. " "Available options are: 'aiohttp', 'httpxasync', 'requests', 'urllib3'" ) @pytest.mark.parametrize(["status", "boolean"], [(200, True), (299, True)]) def test_head_response_true(status, boolean): t = Transport( [NodeConfig("http", "localhost", 80, _extras={"status": status, "body": b""})], node_class=DummyNode, ) resp, data = t.perform_request("HEAD", "/") assert resp.status == status assert data is None def test_head_response_false(): t = Transport( [NodeConfig("http", "localhost", 80, _extras={"status": 404, "body": b""})], node_class=DummyNode, ) meta, resp = t.perform_request("HEAD", "/") assert meta.status == 404 assert resp is None # 404s don't count as a dead node status. assert 0 == len(t.node_pool._dead_nodes.queue) @pytest.mark.parametrize( "node_class", ["urllib3", "requests", Urllib3HttpNode, RequestsHttpNode], ) def test_transport_client_meta_node_class(node_class): t = Transport([NodeConfig("http", "localhost", 80)], node_class=node_class) assert ( t._transport_client_meta[3] == t.node_pool.node_class._CLIENT_META_HTTP_CLIENT ) assert t._transport_client_meta[3][0] in ("ur", "rq") assert re.match( r"^et=[0-9.]+p?,py=[0-9.]+p?,t=[0-9.]+p?,(?:ur|rq)=[0-9.]+p?$", ",".join(f"{k}={v}" for k, v in t._transport_client_meta), ) # Defaults to urllib3 t = Transport([NodeConfig("http", "localhost", 80)]) assert t._transport_client_meta[3][0] == "ur" assert [x[0] for x in t._transport_client_meta[:3]] == ["et", "py", "t"] @pytest.mark.parametrize( "node_class", ["aiohttp", AiohttpHttpNode], ) def test_transport_and_node_are_sync(node_class): with pytest.raises(ValueError) as e: Transport([NodeConfig("http", "localhost", 80)], node_class=node_class) assert str(e.value) == "Specified 'node_class' is async, should be sync instead" def test_client_meta_header(): class DummyNodeWithClientMeta(DummyNode): _CLIENT_META_HTTP_CLIENT = ("dm", "0.0.0p") t = Transport( [NodeConfig("http", "localhost", 80)], node_class=DummyNodeWithClientMeta, client_meta_service=("es", "8.0.0p"), ) t.perform_request("GET", "/") calls = t.node_pool.get().calls assert 1 == len(calls) headers = calls[0][1]["headers"] assert sorted(headers.keys()) == ["x-elastic-client-meta"] assert re.match( r"^es=8\.0\.0p,py=[0-9.]+p?,t=[0-9.]+p?,dm=0\.0\.0p$", headers["x-elastic-client-meta"], ) def test_client_meta_header_extras(): class DummyNodeWithClientMeta(DummyNode): _CLIENT_META_HTTP_CLIENT = ("dm", "0.0.0p") t = Transport( [NodeConfig("http", "localhost", 80)], node_class=DummyNodeWithClientMeta, client_meta_service=("es", "8.0.0p"), ) t.perform_request("GET", "/", client_meta=(("h", "s"),)) calls = t.node_pool.get().calls assert 1 == len(calls) headers = calls[0][1]["headers"] assert sorted(headers.keys()) == ["x-elastic-client-meta"] assert re.match( r"^es=8\.0\.0p,py=[0-9.]+p?,t=[0-9.]+p?,dm=0\.0\.0p,h=s$", headers["x-elastic-client-meta"], ) def test_sniff_on_start(): calls = [] def sniff_callback(*args): nonlocal calls calls.append(args) return [NodeConfig("http", "localhost", 80)] t = Transport( [NodeConfig("http", "localhost", 80)], node_class=DummyNode, sniff_on_start=True, sniff_callback=sniff_callback, ) assert len(calls) == 1 t.perform_request("GET", "/") assert len(calls) == 1 transport, sniff_options = calls[0] assert transport is t assert sniff_options == SniffOptions(is_initial_sniff=True, sniff_timeout=0.5) def test_sniff_before_requests(): calls = [] def sniff_callback(*args): nonlocal calls calls.append(args) return [] t = Transport( [NodeConfig("http", "localhost", 80)], node_class=DummyNode, sniff_before_requests=True, sniff_callback=sniff_callback, ) assert len(calls) == 0 t.perform_request("GET", "/") assert len(calls) == 1 transport, sniff_options = calls[0] assert transport is t assert sniff_options == SniffOptions(is_initial_sniff=False, sniff_timeout=0.5) def test_sniff_on_node_failure(): calls = [] def sniff_callback(*args): nonlocal calls calls.append(args) return [] t = Transport( [ NodeConfig("http", "localhost", 80), NodeConfig("http", "localhost", 81, _extras={"status": 500}), ], randomize_nodes_in_pool=False, node_selector_class="round_robin", node_class=DummyNode, max_retries=1, sniff_on_node_failure=True, sniff_callback=sniff_callback, ) assert len(calls) == 0 t.perform_request("GET", "/") # 200 assert len(calls) == 0 t.perform_request("GET", "/") # 500 assert len(calls) == 1 transport, sniff_options = calls[0] assert transport is t assert sniff_options == SniffOptions(is_initial_sniff=False, sniff_timeout=0.5) @pytest.mark.parametrize( "kwargs", [ {"sniff_on_start": True}, {"sniff_on_node_failure": True}, {"sniff_before_requests": True}, ], ) def test_error_with_sniffing_enabled_without_callback(kwargs): with pytest.raises(ValueError) as e: Transport([NodeConfig("http", "localhost", 80)], **kwargs) assert str(e.value) == "Enabling sniffing requires specifying a 'sniff_callback'" def test_error_sniffing_callback_without_sniffing_enabled(): with pytest.raises(ValueError) as e: Transport([NodeConfig("http", "localhost", 80)], sniff_callback=lambda *_: []) assert str(e.value) == ( "Using 'sniff_callback' requires enabling sniffing via 'sniff_on_start', " "'sniff_before_requests' or 'sniff_on_node_failure'" ) def test_heterogeneous_node_config_warning_with_sniffing(): with warnings.catch_warnings(record=True) as w: context = ssl.create_default_context() Transport( [ NodeConfig( "https", "localhost", 80, path_prefix="/a", ssl_context=context ), NodeConfig( "https", "localhost", 81, path_prefix="/b", ssl_context=context ), ], sniff_on_start=True, sniff_callback=lambda *_: [ NodeConfig("https", "localhost", 80, path_prefix="/a") ], ) assert len(w) == 1 assert w[0].category == TransportWarning assert str(w[0].message) == ( "Detected NodeConfig instances with different options. It's " "recommended to keep all options except for 'host' and 'port' " "the same for sniffing to work reliably." ) def test_sniffed_nodes_added_to_pool(): sniffed_nodes = [ NodeConfig("http", "localhost", 80), NodeConfig("http", "localhost", 81), ] t = Transport( [ NodeConfig("http", "localhost", 80), ], node_class=DummyNode, sniff_before_requests=True, sniff_callback=lambda *_: sniffed_nodes, ) assert len(t.node_pool) == 1 t.perform_request("GET", "/") # The node pool knows when nodes are already in the pool # so we shouldn't get duplicates after sniffing. assert len(t.node_pool) == 2 assert set(sniffed_nodes) == {node.config for node in t.node_pool.all()} def test_sniff_error_resets_lock_and_last_sniffed_at(): def sniff_error(*_): raise TransportError("This is an error!") t = Transport( [ NodeConfig("http", "localhost", 80), ], node_class=DummyNode, sniff_before_requests=True, sniff_callback=sniff_error, ) last_sniffed_at = t._last_sniffed_at with pytest.raises(TransportError) as e: t.perform_request("GET", "/") assert str(e.value) == "This is an error!" assert t._last_sniffed_at == last_sniffed_at assert t._sniffing_lock.locked() is False def test_sniff_on_start_no_results_errors(): with pytest.raises(SniffingError) as e: Transport( [ NodeConfig("http", "localhost", 80), ], node_class=DummyNode, sniff_on_start=True, sniff_callback=lambda *_: [], ) assert ( str(e.value) == "No viable nodes were discovered on the initial sniff attempt" ) @pytest.mark.parametrize("pool_size", [1, 8]) def test_threading_test(pool_size): node_configs = [ NodeConfig("http", "localhost", 80), NodeConfig("http", "localhost", 81), NodeConfig("http", "localhost", 82), NodeConfig("http", "localhost", 83, _extras={"status": 500}), ] def sniff_callback(*_): time.sleep(random.random()) return node_configs t = Transport( node_configs, retry_on_status=[500], max_retries=5, node_class=DummyNode, sniff_on_start=True, sniff_before_requests=True, sniff_on_node_failure=True, sniff_callback=sniff_callback, ) class ThreadTest(threading.Thread): def __init__(self): super().__init__() self.successful_requests = 0 def run(self) -> None: nonlocal t, start while time.time() < start + 2: t.perform_request("GET", "/") self.successful_requests += 1 threads = [ThreadTest() for _ in range(pool_size * 2)] start = time.time() [thread.start() for thread in threads] [thread.join() for thread in threads] assert sum(thread.successful_requests for thread in threads) >= 1000 def test_httpbin(httpbin_node_config): t = Transport([httpbin_node_config]) resp = t.perform_request("GET", "/anything") assert resp.meta.status == 200 assert isinstance(resp.body, dict) elastic-transport-python-8.17.1/tests/test_utils.py000066400000000000000000000030301476450415400225260ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import pytest from elastic_transport._utils import is_ipaddress @pytest.mark.parametrize( "addr", [ # IPv6 "::1", "::", "FE80::8939:7684:D84b:a5A4%251", # IPv4 "127.0.0.1", "8.8.8.8", b"127.0.0.1", # IPv6 w/ Zone IDs "FE80::8939:7684:D84b:a5A4%251", b"FE80::8939:7684:D84b:a5A4%251", "FE80::8939:7684:D84b:a5A4%19", b"FE80::8939:7684:D84b:a5A4%19", ], ) def test_is_ipaddress(addr): assert is_ipaddress(addr) @pytest.mark.parametrize( "addr", [ "www.python.org", b"www.python.org", "v2.sg.media-imdb.com", b"v2.sg.media-imdb.com", ], ) def test_is_not_ipaddress(addr): assert not is_ipaddress(addr) elastic-transport-python-8.17.1/utils/000077500000000000000000000000001476450415400177575ustar00rootroot00000000000000elastic-transport-python-8.17.1/utils/build-dists.py000066400000000000000000000071321476450415400225570ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. """A command line tool for building and verifying releases Can be used for building both 'elasticsearch' and 'elasticsearchX' dists. Only requires 'name' in 'setup.py' and the directory to be changed. """ import contextlib import os import re import shutil import tempfile base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) tmp_dir = None def shlex_quote(s): # Backport of shlex.quote() to Python 2.x _find_unsafe = re.compile(r"[^\w@%+=:,./-]").search if not s: return "''" if _find_unsafe(s) is None: return s # use single quotes, and put single quotes into double quotes # the string $'b is then quoted as '$'"'"'b' return "'" + s.replace("'", "'\"'\"'") + "'" @contextlib.contextmanager def set_tmp_dir(): global tmp_dir tmp_dir = tempfile.mkdtemp() yield tmp_dir shutil.rmtree(tmp_dir) tmp_dir = None def run(argv, expect_exit_code=0): global tmp_dir if tmp_dir is None: os.chdir(base_dir) else: os.chdir(tmp_dir) cmd = " ".join(shlex_quote(x) for x in argv) print("$ " + cmd) exit_code = os.system(cmd) if exit_code != expect_exit_code: print( "Command exited incorrectly: should have been %d was %d" % (expect_exit_code, exit_code) ) exit(exit_code or 1) def test_dist(dist): with set_tmp_dir() as tmp_dir: # Build the venv and install the dist run(("python", "-m", "venv", os.path.join(tmp_dir, "venv"))) venv_python = os.path.join(tmp_dir, "venv/bin/python") run((venv_python, "-m", "pip", "install", "-U", "pip")) run((venv_python, "-m", "pip", "install", dist)) # Test out importing from the package run( ( venv_python, "-c", "from elastic_transport import Transport, Urllib3HttpNode, RequestsHttpNode", ) ) # Uninstall the dist, see that we can't import things anymore run((venv_python, "-m", "pip", "uninstall", "--yes", "elastic-transport")) run( (venv_python, "-c", "from elastic_transport import Transport"), expect_exit_code=256, ) def main(): run(("rm", "-rf", "build/", "dist/", "*.egg-info", ".eggs")) # Install and run python-build to create sdist/wheel run(("python", "-m", "pip", "install", "-U", "build")) run(("python", "-m", "build")) for dist in os.listdir(os.path.join(base_dir, "dist")): test_dist(os.path.join(base_dir, "dist", dist)) # After this run 'python -m twine upload dist/*' print( "\n\n" "===============================\n\n" " * Releases are ready! *\n\n" "$ python -m twine upload dist/*\n\n" "===============================" ) if __name__ == "__main__": main() elastic-transport-python-8.17.1/utils/license-headers.py000066400000000000000000000104471476450415400233720ustar00rootroot00000000000000# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. """Script which verifies that all source files have a license header. Has two modes: 'fix' and 'check'. 'fix' fixes problems, 'check' will error out if 'fix' would have changed the file. """ import os import sys from itertools import chain from typing import Iterator, List lines_to_keep = ["# -*- coding: utf-8 -*-\n", "#!/usr/bin/env python\n"] license_header_lines = [ "# Licensed to Elasticsearch B.V. under one or more contributor\n", "# license agreements. See the NOTICE file distributed with\n", "# this work for additional information regarding copyright\n", "# ownership. Elasticsearch B.V. licenses this file to you under\n", '# the Apache License, Version 2.0 (the "License"); you may\n', "# not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# http://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing,\n", "# software distributed under the License is distributed on an\n", '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n', "# KIND, either express or implied. See the License for the\n", "# specific language governing permissions and limitations\n", "# under the License.\n", "\n", ] def find_files_to_fix(sources: List[str]) -> Iterator[str]: """Iterates over all files and dirs in 'sources' and returns only the filepaths that need fixing. """ for source in sources: if os.path.isfile(source) and does_file_need_fix(source): yield source elif os.path.isdir(source): for root, _, filenames in os.walk(source): for filename in filenames: filepath = os.path.join(root, filename) if does_file_need_fix(filepath): yield filepath def does_file_need_fix(filepath: str) -> bool: if not filepath.endswith(".py"): return False with open(filepath) as f: first_license_line = None for line in f: if line == license_header_lines[0]: first_license_line = line break elif line not in lines_to_keep: return True for header_line, line in zip( license_header_lines, chain((first_license_line,), f) ): if line != header_line: return True return False def add_header_to_file(filepath: str) -> None: with open(filepath) as f: lines = list(f) i = 0 for i, line in enumerate(lines): if line not in lines_to_keep: break lines = lines[:i] + license_header_lines + lines[i:] with open(filepath, mode="w") as f: f.truncate() f.write("".join(lines)) print(f"Fixed {os.path.relpath(filepath, os.getcwd())}") def main(): mode = sys.argv[1] assert mode in ("fix", "check") sources = [os.path.abspath(x) for x in sys.argv[2:]] files_to_fix = find_files_to_fix(sources) if mode == "fix": for filepath in files_to_fix: add_header_to_file(filepath) else: no_license_headers = list(files_to_fix) if no_license_headers: print("No license header found in:") cwd = os.getcwd() [ print(f" - {os.path.relpath(filepath, cwd)}") for filepath in no_license_headers ] sys.exit(1) else: print("All files had license header") if __name__ == "__main__": main()