===== array-api-compat-1.11.2/.github/dependabot.yml =====

version: 2
updates:
  # Maintain dependencies for GitHub Actions
  - package-ecosystem: "github-actions"
    directory: "/"
    schedule:
      interval: "weekly"
    groups:
      actions:
        patterns:
          - "*"
    labels:
      - "github-actions"
      - "dependencies"
    reviewers:
      - "asmeurer"

===== array-api-compat-1.11.2/.github/workflows/array-api-tests-dask.yml =====

name: Array API Tests (Dask)

on: [push, pull_request]

jobs:
  array-api-tests-dask:
    uses: ./.github/workflows/array-api-tests.yml
    with:
      package-name: dask
      package-version: '>= 2024.9.0'
      module-name: dask.array
      extra-requires: numpy
      # Dask is substantially slower than other libraries on unit tests.
      # Reduce the number of examples to speed up CI, even though this means
      # that this workflow is barely more than a smoke test, and one should
      # expect extreme flakiness. Before changes to dask-xfails.txt or
      # dask-skips.txt, please run the full test suite with at least 200
      # examples.
      pytest-extra-args: --max-examples=5

===== array-api-compat-1.11.2/.github/workflows/array-api-tests-numpy-1-21.yml =====

name: Array API Tests (NumPy 1.21)

on: [push, pull_request]

jobs:
  array-api-tests-numpy-1-21:
    uses: ./.github/workflows/array-api-tests.yml
    with:
      package-name: numpy
      package-version: '== 1.21.*'
      xfails-file-extra: '-1-21'

===== array-api-compat-1.11.2/.github/workflows/array-api-tests-numpy-1-26.yml =====

name: Array API Tests (NumPy 1.26)

on: [push, pull_request]

jobs:
  array-api-tests-numpy-latest:
    uses: ./.github/workflows/array-api-tests.yml
    with:
      package-name: numpy
      package-version: '== 1.26.*'
      xfails-file-extra: '-1-26'

===== array-api-compat-1.11.2/.github/workflows/array-api-tests-numpy-dev.yml =====

name: Array API Tests (NumPy dev)

on: [push, pull_request]

jobs:
  array-api-tests-numpy-dev:
    uses: ./.github/workflows/array-api-tests.yml
    with:
      package-name: numpy
      extra-requires: '--pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple'
      xfails-file-extra: '-dev'

===== array-api-compat-1.11.2/.github/workflows/array-api-tests-numpy-latest.yml =====

name: Array API Tests (NumPy Latest)

on: [push, pull_request]

jobs:
  array-api-tests-numpy-latest:
    uses: ./.github/workflows/array-api-tests.yml
    with:
      package-name: numpy

===== array-api-compat-1.11.2/.github/workflows/array-api-tests-torch.yml =====

name: Array API Tests (PyTorch Latest)

on: [push, pull_request]

jobs:
  array-api-tests-torch:
    uses: ./.github/workflows/array-api-tests.yml
    with:
      package-name: torch
      extra-env-vars: |
        ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64
===== array-api-compat-1.11.2/.github/workflows/array-api-tests.yml =====

name: Array API Tests

on:
  workflow_call:
    inputs:
      package-name:
        required: true
        type: string
      module-name:
        required: false
        type: string
      extra-requires:
        required: false
        type: string
      package-version:
        required: false
        type: string
        default: '>= 0'
      pytest-extra-args:
        required: false
        type: string
      # This is not how I would prefer to implement this but it's the only way
      # that seems possible with GitHub Actions' limited expressions syntax
      xfails-file-extra:
        required: false
        type: string
      skips-file-extra:
        required: false
        type: string
      extra-env-vars:
        required: false
        type: string
        description: "Multiline string of environment variables to set for the test run."

env:
  PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline --durations 10"

jobs:
  tests:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        # The minimum dask version we need has dropped support for Python 3.9.
        # There are no numpy git-tip (dev) wheels for Python 3.9 or 3.10.
        python-version: ${{ (inputs.package-name == 'dask' && fromJson('[''3.10'', ''3.11'', ''3.12'']')) || (inputs.package-name == 'numpy' && inputs.xfails-file-extra == '-dev' && fromJson('[''3.11'', ''3.12'']')) || fromJson('[''3.9'', ''3.10'', ''3.11'', ''3.12'']') }}

    steps:
      - name: Checkout array-api-compat
        uses: actions/checkout@v4
        with:
          path: array-api-compat

      - name: Checkout array-api-tests
        uses: actions/checkout@v4
        with:
          repository: data-apis/array-api-tests
          submodules: 'true'
          path: array-api-tests

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Set Extra Environment Variables
        # Set additional environment variables if provided
        if: inputs.extra-env-vars
        run: |
          echo "${{ inputs.extra-env-vars }}" >> $GITHUB_ENV

      - name: Install dependencies
        # NumPy 1.21 doesn't support Python 3.11. There doesn't seem to be a way
        # to put this in the numpy 1.21 config file.
        if: "! ((matrix.python-version == '3.11' || matrix.python-version == '3.12') && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))"
        run: |
          python -m pip install --upgrade pip
          python -m pip install '${{ inputs.package-name }} ${{ inputs.package-version }}' ${{ inputs.extra-requires }}
          python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt

      - name: Run the array API testsuite (${{ inputs.package-name }})
        if: "! ((matrix.python-version == '3.11' || matrix.python-version == '3.12') && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))"
        env:
          ARRAY_API_TESTS_MODULE: array_api_compat.${{ inputs.module-name || inputs.package-name }}
          ARRAY_API_TESTS_VERSION: 2024.12
          # This enables the NEP 50 type promotion behavior (without it a lot of
          # tests fail on bad scalar type promotion behavior)
          NPY_PROMOTION_STATE: weak
        run: |
          export PYTHONPATH="${GITHUB_WORKSPACE}/array-api-compat"
          cd ${GITHUB_WORKSPACE}/array-api-tests
          pytest array_api_tests/ --xfails-file ${GITHUB_WORKSPACE}/array-api-compat/${{ inputs.package-name }}${{ inputs.xfails-file-extra }}-xfails.txt --skips-file ${GITHUB_WORKSPACE}/array-api-compat/${{ inputs.package-name }}${{ inputs.skips-file-extra }}-skips.txt ${PYTEST_ARGS}

===== array-api-compat-1.11.2/.github/workflows/dependabot-auto-merge.yml =====

# https://docs.github.com/en/code-security/dependabot/working-with-dependabot/automating-dependabot-with-github-actions#approve-a-pull-request
name: Dependabot auto-merge
on: pull_request

permissions:
  contents: write
  pull-requests: write

jobs:
  dependabot:
    runs-on: ubuntu-latest
    if: github.actor == 'dependabot[bot]'
    steps:
      - name: Dependabot metadata
        id: metadata
        uses: dependabot/fetch-metadata@v2
        with:
          github-token: "${{ secrets.GITHUB_TOKEN }}"
      - name: Enable auto-merge for Dependabot PRs
        run: gh pr merge --auto --merge "$PR_URL"
        env:
          PR_URL: ${{github.event.pull_request.html_url}}
          GH_TOKEN: ${{secrets.GITHUB_TOKEN}}

===== array-api-compat-1.11.2/.github/workflows/docs-build.yml =====

name: Docs Build

on: [push, pull_request]

jobs:
  docs-build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
      - name: Install Dependencies
        run: |
          python -m pip install -r docs/requirements.txt
      - name: Build Docs
        run: |
          cd docs
          make html
      - name: Upload Artifact
        uses: actions/upload-artifact@v4
        with:
          name: docs-build
          path: docs/_build/html

===== array-api-compat-1.11.2/.github/workflows/docs-deploy.yml =====

name: Docs Deploy

on:
  push:
    branches:
      - main

jobs:
  docs-deploy:
    runs-on: ubuntu-latest
    environment:
      name: docs-deploy
    steps:
      - uses: actions/checkout@v4
      - name: Download Artifact
        uses: dawidd6/action-download-artifact@v9
        with:
          workflow: docs-build.yml
          name: docs-build
          path: docs/_build/html

      # Note, the gh-pages deployment requires setting up a SSH deploy key.
      # See
      # https://github.com/JamesIves/github-pages-deploy-action/tree/dev#using-an-ssh-deploy-key-
      - name: Deploy
        uses: JamesIves/github-pages-deploy-action@v4
        with:
          folder: docs/_build/html
          ssh-key: ${{ secrets.DEPLOY_KEY }}
          force: no

===== array-api-compat-1.11.2/.github/workflows/publish-package.yml =====

name: publish distributions
on:
  push:
    branches:
      - main
    tags:
      - '[0-9]+.[0-9]+'
      - '[0-9]+.[0-9]+.[0-9]+'
  pull_request:
    branches:
      - main
  release:
    types: [published]
  workflow_dispatch:
    inputs:
      publish:
        type: choice
        description: 'Publish to TestPyPI?'
        options:
          - false
          - true

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  build:
    name: Build Python distribution
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.x'

      - name: Install python-build and twine
        run: |
          python -m pip install --upgrade pip "setuptools<=67"
          python -m pip install build twine
          python -m pip list

      - name: Build a wheel and a sdist
        run: |
          #PYTHONWARNINGS=error,default::DeprecationWarning python -m build .
          python -m build .

      - name: Verify the distribution
        run: twine check --strict dist/*

      - name: List contents of sdist
        run: python -m tarfile --list dist/array_api_compat-*.tar.gz

      - name: List contents of wheel
        run: python -m zipfile --list dist/array_api_compat-*.whl

      - name: Upload distribution artifact
        uses: actions/upload-artifact@v4
        with:
          name: dist-artifact
          path: dist

  publish:
    name: Publish Python distribution to (Test)PyPI
    if: github.event_name != 'pull_request' && github.repository == 'data-apis/array-api-compat' && github.ref_type == 'tag'
    needs: build
    runs-on: ubuntu-latest
    # Mandatory for publishing with a trusted publisher
    # c.f. https://docs.pypi.org/trusted-publishers/using-a-publisher/
    permissions:
      id-token: write
      contents: write
    # Restrict to the environment set for the trusted publisher
    environment:
      name: publish-package

    steps:
      - name: Download distribution artifact
        uses: actions/download-artifact@v4
        with:
          name: dist-artifact
          path: dist

      - name: List all files
        run: ls -lh dist

      # - name: Publish distribution 📦 to Test PyPI
      #   # Publish to TestPyPI on tag events or if manually triggered
      #   # Compare to 'true' string as booleans get turned into strings in the console
      #   if: >-
      #     (github.event_name == 'push' && startsWith(github.ref, 'refs/tags'))
      #     || (github.event_name == 'workflow_dispatch' && github.event.inputs.publish == 'true')
      #   uses: pypa/gh-action-pypi-publish@v1.12.4
      #   with:
      #     repository-url: https://test.pypi.org/legacy/
      #     print-hash: true

      - name: Publish distribution 📦 to PyPI
        if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
        uses: pypa/gh-action-pypi-publish@v1.12.4
        with:
          print-hash: true

      - name: Create GitHub Release from a Tag
        uses: softprops/action-gh-release@v2
        if: startsWith(github.ref, 'refs/tags/')
        with:
          files: dist/*

===== array-api-compat-1.11.2/.github/workflows/ruff.yml =====

name: CI

on: [push, pull_request]

jobs:
  check-ruff:
    runs-on: ubuntu-latest
    continue-on-error: true
    steps:
      - uses: actions/checkout@v4
      - name: Install Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.11"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install ruff
      # Update output format to enable automatic inline annotations.
      - name: Run Ruff
        run: ruff check --output-format=github .
===== array-api-compat-1.11.2/.github/workflows/tests.yml =====

name: Tests

on: [push, pull_request]

jobs:
  tests:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ['3.9', '3.10', '3.11', '3.12']
        numpy-version: ['1.21', '1.26', '2.0', 'dev']
        exclude:
          - python-version: '3.11'
            numpy-version: '1.21'
          - python-version: '3.12'
            numpy-version: '1.21'
      fail-fast: true
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install Dependencies
        run: |
          python -m pip install --upgrade pip
          if [ "${{ matrix.numpy-version }}" == "dev" ]; then
            PIP_EXTRA='numpy --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple'
          elif [ "${{ matrix.numpy-version }}" == "1.21" ]; then
            PIP_EXTRA='numpy==1.21.*'
          else
            PIP_EXTRA='numpy==1.26.*'
          fi

          if [ "${{ matrix.python-version }}" == "3.9" ]; then
            sed -i '/^ndonnx/d' requirements-dev.txt
          fi

          python -m pip install -r requirements-dev.txt $PIP_EXTRA

      - name: Run Tests
        run: |
          if [[ "${{ matrix.numpy-version }}" == "1.21" || "${{ matrix.numpy-version }}" == "dev" ]]; then
            PYTEST_EXTRA=(-k "numpy and not jax and not torch and not dask and not sparse")
          fi
          pytest -v "${PYTEST_EXTRA[@]}"

          # Make sure it installs
          python -m pip install .

===== array-api-compat-1.11.2/.gitignore =====

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# macOS specific files
.DS_Store

===== array-api-compat-1.11.2/CHANGELOG.md -> docs/changelog.md (symlink) =====

===== array-api-compat-1.11.2/CONTRIBUTING.md =====

Contributions to array-api-compat are welcome, so long as they are [in
scope](https://data-apis.org/array-api-compat/index.html#scope).

Contributors are encouraged to read through the [development
notes](https://data-apis.org/array-api-compat/dev/index.html) for the package
to get full context on some of the design decisions and implementation details
used in the codebase.

===== array-api-compat-1.11.2/LICENSE =====

MIT License

Copyright (c) 2022 Consortium for Python Data API Standards

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

===== array-api-compat-1.11.2/README.md =====

# Array API compatibility library

This is a small wrapper around common array libraries that is compatible with
the [Array API standard](https://data-apis.org/array-api/latest/). Currently,
NumPy, CuPy, PyTorch, Dask, JAX, ndonnx and `sparse` are supported. If you want
support for other array libraries, or if you encounter any issues, please
[open an issue](https://github.com/data-apis/array-api-compat/issues).

See the documentation for more details: https://data-apis.org/array-api-compat/

===== array-api-compat-1.11.2/array_api_compat/__init__.py =====

"""
NumPy Array API compatibility library

This is a small wrapper around NumPy, CuPy, JAX, sparse and others that are
compatible with the Array API standard https://data-apis.org/array-api/latest/.
See also NEP 47 https://numpy.org/neps/nep-0047-array-api-standard.html.

Unlike array_api_strict, this is not a strict minimal implementation of the
Array API, but rather just an extension of the main NumPy namespace with
changes needed to be compliant with the Array API. See
https://numpy.org/doc/stable/reference/array_api.html for a full list of
changes. In particular, unlike array_api_strict, this package does not use a
separate Array object, but rather just uses numpy.ndarray directly.

Library authors using the Array API may wish to test against array_api_strict
to ensure they are not using functionality outside of the standard, but prefer
this implementation for the default when working with NumPy arrays.
"""
__version__ = '1.11.2'

from .common import *  # noqa: F401, F403

===== array-api-compat-1.11.2/array_api_compat/_internal.py =====

"""
Internal helpers
"""

from functools import wraps
from inspect import signature

def get_xp(xp):
    """
    Decorator to automatically replace xp with the corresponding array module.

    Use like

    import numpy as np

    @get_xp(np)
    def func(x, /, xp, kwarg=None):
        return xp.func(x, kwarg=kwarg)

    Note that xp must be a keyword argument and come after all non-keyword
    arguments.
    """
    def inner(f):
        @wraps(f)
        def wrapped_f(*args, **kwargs):
            return f(*args, xp=xp, **kwargs)

        sig = signature(f)
        new_sig = sig.replace(
            parameters=[sig.parameters[i] for i in sig.parameters if i != "xp"]
        )

        if wrapped_f.__doc__ is None:
            wrapped_f.__doc__ = f"""\
Array API compatibility wrapper for {f.__name__}.

See the corresponding documentation in NumPy/CuPy and/or the array API
specification for more details.
"""
        wrapped_f.__signature__ = new_sig
        return wrapped_f

    return inner

===== array-api-compat-1.11.2/array_api_compat/common/__init__.py =====

from ._helpers import *  # noqa: F403

===== array-api-compat-1.11.2/array_api_compat/common/_aliases.py =====

"""
These are functions that are just aliases of existing functions in NumPy.
"""

from __future__ import annotations

from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from typing import Optional, Sequence, Tuple, Union
    from ._typing import ndarray, Device, Dtype

from typing import NamedTuple
import inspect

from ._helpers import array_namespace, _check_device, device, is_cupy_namespace

# These functions are modified from the NumPy versions.
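# Illustrative sketch (not part of the original module): each wrapper below
# takes the backing namespace as an explicit ``xp`` keyword argument, which
# ``get_xp`` from ``.._internal`` later binds to a concrete library. Assuming
# NumPy is installed, a direct call to the ``arange`` wrapper defined below
# would look like:
#
#     import numpy as np
#     arange(0, 5, xp=np)                  # same result as np.arange(0, 5)
#     arange(0, 5, xp=np, device="cpu")    # device is validated, then ignored by NumPy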
# Creation functions add the device keyword (which does nothing for NumPy) def arange( start: Union[int, float], /, stop: Optional[Union[int, float]] = None, step: Union[int, float] = 1, *, xp, dtype: Optional[Dtype] = None, device: Optional[Device] = None, **kwargs ) -> ndarray: _check_device(xp, device) return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs) def empty( shape: Union[int, Tuple[int, ...]], xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, **kwargs ) -> ndarray: _check_device(xp, device) return xp.empty(shape, dtype=dtype, **kwargs) def empty_like( x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, **kwargs ) -> ndarray: _check_device(xp, device) return xp.empty_like(x, dtype=dtype, **kwargs) def eye( n_rows: int, n_cols: Optional[int] = None, /, *, xp, k: int = 0, dtype: Optional[Dtype] = None, device: Optional[Device] = None, **kwargs, ) -> ndarray: _check_device(xp, device) return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs) def full( shape: Union[int, Tuple[int, ...]], fill_value: Union[int, float], xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, **kwargs, ) -> ndarray: _check_device(xp, device) return xp.full(shape, fill_value, dtype=dtype, **kwargs) def full_like( x: ndarray, /, fill_value: Union[int, float], *, xp, dtype: Optional[Dtype] = None, device: Optional[Device] = None, **kwargs, ) -> ndarray: _check_device(xp, device) return xp.full_like(x, fill_value, dtype=dtype, **kwargs) def linspace( start: Union[int, float], stop: Union[int, float], /, num: int, *, xp, dtype: Optional[Dtype] = None, device: Optional[Device] = None, endpoint: bool = True, **kwargs, ) -> ndarray: _check_device(xp, device) return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs) def ones( shape: Union[int, Tuple[int, ...]], xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, **kwargs, ) -> ndarray: _check_device(xp, device) return xp.ones(shape, dtype=dtype, **kwargs) def ones_like( x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, **kwargs, ) -> ndarray: _check_device(xp, device) return xp.ones_like(x, dtype=dtype, **kwargs) def zeros( shape: Union[int, Tuple[int, ...]], xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, **kwargs, ) -> ndarray: _check_device(xp, device) return xp.zeros(shape, dtype=dtype, **kwargs) def zeros_like( x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, **kwargs, ) -> ndarray: _check_device(xp, device) return xp.zeros_like(x, dtype=dtype, **kwargs) # np.unique() is split into four functions in the array API: # unique_all, unique_counts, unique_inverse, and unique_values (this is done # to remove polymorphic return types). # The functions here return namedtuples (np.unique() returns a normal # tuple). # Note that these named tuples aren't actually part of the standard namespace, # but I don't see any issue with exporting the names here regardless. class UniqueAllResult(NamedTuple): values: ndarray indices: ndarray inverse_indices: ndarray counts: ndarray class UniqueCountsResult(NamedTuple): values: ndarray counts: ndarray class UniqueInverseResult(NamedTuple): values: ndarray inverse_indices: ndarray def _unique_kwargs(xp): # Older versions of NumPy and CuPy do not have equal_nan. Rather than # trying to parse version numbers, just check if equal_nan is in the # signature. 
s = inspect.signature(xp.unique) if 'equal_nan' in s.parameters: return {'equal_nan': False} return {} def unique_all(x: ndarray, /, xp) -> UniqueAllResult: kwargs = _unique_kwargs(xp) values, indices, inverse_indices, counts = xp.unique( x, return_counts=True, return_index=True, return_inverse=True, **kwargs, ) # np.unique() flattens inverse indices, but they need to share x's shape # See https://github.com/numpy/numpy/issues/20638 inverse_indices = inverse_indices.reshape(x.shape) return UniqueAllResult( values, indices, inverse_indices, counts, ) def unique_counts(x: ndarray, /, xp) -> UniqueCountsResult: kwargs = _unique_kwargs(xp) res = xp.unique( x, return_counts=True, return_index=False, return_inverse=False, **kwargs ) return UniqueCountsResult(*res) def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult: kwargs = _unique_kwargs(xp) values, inverse_indices = xp.unique( x, return_counts=False, return_index=False, return_inverse=True, **kwargs, ) # xp.unique() flattens inverse indices, but they need to share x's shape # See https://github.com/numpy/numpy/issues/20638 inverse_indices = inverse_indices.reshape(x.shape) return UniqueInverseResult(values, inverse_indices) def unique_values(x: ndarray, /, xp) -> ndarray: kwargs = _unique_kwargs(xp) return xp.unique( x, return_counts=False, return_index=False, return_inverse=False, **kwargs, ) # These functions have different keyword argument names def std( x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, # correction instead of ddof keepdims: bool = False, **kwargs, ) -> ndarray: return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs) def var( x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, # correction instead of ddof keepdims: bool = False, **kwargs, ) -> ndarray: return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs) # cumulative_sum is renamed from cumsum, and adds the include_initial keyword # argument def cumulative_sum( x: ndarray, /, xp, *, axis: Optional[int] = None, dtype: Optional[Dtype] = None, include_initial: bool = False, **kwargs ) -> ndarray: wrapped_xp = array_namespace(x) # TODO: The standard is not clear about what should happen when x.ndim == 0. if axis is None: if x.ndim > 1: raise ValueError("axis must be specified in cumulative_sum for more than one dimension") axis = 0 res = xp.cumsum(x, axis=axis, dtype=dtype, **kwargs) # np.cumsum does not support include_initial if include_initial: initial_shape = list(x.shape) initial_shape[axis] = 1 res = xp.concatenate( [wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=device(res)), res], axis=axis, ) return res def cumulative_prod( x: ndarray, /, xp, *, axis: Optional[int] = None, dtype: Optional[Dtype] = None, include_initial: bool = False, **kwargs ) -> ndarray: wrapped_xp = array_namespace(x) if axis is None: if x.ndim > 1: raise ValueError("axis must be specified in cumulative_prod for more than one dimension") axis = 0 res = xp.cumprod(x, axis=axis, dtype=dtype, **kwargs) # np.cumprod does not support include_initial if include_initial: initial_shape = list(x.shape) initial_shape[axis] = 1 res = xp.concatenate( [wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=device(res)), res], axis=axis, ) return res # The min and max argument names in clip are different and not optional in numpy, and type # promotion behavior is different. 
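# Illustrative sketch of that difference (hypothetical snippet, assuming NumPy
# as the backing namespace; the exact promotion of np.clip also depends on the
# NumPy version):
#
#     import numpy as np
#     x = np.asarray([1, 2], dtype=np.int8)
#     np.clip(x, np.int16(0), np.int16(300)).dtype   # may promote beyond int8
#     clip(x, 0, 300, xp=np).dtype                   # int8: the result keeps x.dtype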
def clip( x: ndarray, /, min: Optional[Union[int, float, ndarray]] = None, max: Optional[Union[int, float, ndarray]] = None, *, xp, # TODO: np.clip has other ufunc kwargs out: Optional[ndarray] = None, ) -> ndarray: def _isscalar(a): return isinstance(a, (int, float, type(None))) min_shape = () if _isscalar(min) else min.shape max_shape = () if _isscalar(max) else max.shape wrapped_xp = array_namespace(x) result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape) # np.clip does type promotion but the array API clip requires that the # output have the same dtype as x. We do this instead of just downcasting # the result of xp.clip() to handle some corner cases better (e.g., # avoiding uint64 -> float64 promotion). # Note: cases where min or max overflow (integer) or round (float) in the # wrong direction when downcasting to x.dtype are unspecified. This code # just does whatever NumPy does when it downcasts in the assignment, but # other behavior could be preferred, especially for integers. For example, # this code produces: # >>> clip(asarray(0, dtype=int8), asarray(128, dtype=int16), None) # -128 # but an answer of 0 might be preferred. See # https://github.com/numpy/numpy/issues/24976 for more discussion on this issue. # At least handle the case of Python integers correctly (see # https://github.com/numpy/numpy/pull/26892). if wrapped_xp.isdtype(x.dtype, "integral"): if type(min) is int and min <= wrapped_xp.iinfo(x.dtype).min: min = None if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max: max = None dev = device(x) if out is None: out = wrapped_xp.empty(result_shape, dtype=x.dtype, device=dev) out[()] = x if min is not None: a = wrapped_xp.asarray(min, dtype=x.dtype, device=dev) a = xp.broadcast_to(a, result_shape) ia = (out < a) | xp.isnan(a) out[ia] = a[ia] if max is not None: b = wrapped_xp.asarray(max, dtype=x.dtype, device=dev) b = xp.broadcast_to(b, result_shape) ib = (out > b) | xp.isnan(b) out[ib] = b[ib] # Return a scalar for 0-D return out[()] # Unlike transpose(), the axes argument to permute_dims() is required. def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray: return xp.transpose(x, axes) # np.reshape calls the keyword argument 'newshape' instead of 'shape' def reshape(x: ndarray, /, shape: Tuple[int, ...], xp, copy: Optional[bool] = None, **kwargs) -> ndarray: if copy is True: x = x.copy() elif copy is False: y = x.view() y.shape = shape return y return xp.reshape(x, shape, **kwargs) # The descending keyword is new in sort and argsort, and 'kind' replaced with # 'stable' def argsort( x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs, ) -> ndarray: # Note: this keyword argument is different, and the default is different. # We set it in kwargs like this because numpy.sort uses kind='quicksort' # as the default whereas cupy.sort uses kind=None. if stable: kwargs['kind'] = "stable" if not descending: res = xp.argsort(x, axis=axis, **kwargs) else: # As NumPy has no native descending sort, we imitate it here. Note that # simply flipping the results of xp.argsort(x, ...) would not # respect the relative order like it would in native descending sorts. 
res = xp.flip( xp.argsort(xp.flip(x, axis=axis), axis=axis, **kwargs), axis=axis, ) # Rely on flip()/argsort() to validate axis normalised_axis = axis if axis >= 0 else x.ndim + axis max_i = x.shape[normalised_axis] - 1 res = max_i - res return res def sort( x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs, ) -> ndarray: # Note: this keyword argument is different, and the default is different. # We set it in kwargs like this because numpy.sort uses kind='quicksort' # as the default whereas cupy.sort uses kind=None. if stable: kwargs['kind'] = "stable" res = xp.sort(x, axis=axis, **kwargs) if descending: res = xp.flip(res, axis=axis) return res # nonzero should error for zero-dimensional arrays def nonzero(x: ndarray, /, xp, **kwargs) -> Tuple[ndarray, ...]: if x.ndim == 0: raise ValueError("nonzero() does not support zero-dimensional arrays") return xp.nonzero(x, **kwargs) # ceil, floor, and trunc return integers for integer inputs def ceil(x: ndarray, /, xp, **kwargs) -> ndarray: if xp.issubdtype(x.dtype, xp.integer): return x return xp.ceil(x, **kwargs) def floor(x: ndarray, /, xp, **kwargs) -> ndarray: if xp.issubdtype(x.dtype, xp.integer): return x return xp.floor(x, **kwargs) def trunc(x: ndarray, /, xp, **kwargs) -> ndarray: if xp.issubdtype(x.dtype, xp.integer): return x return xp.trunc(x, **kwargs) # linear algebra functions def matmul(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray: return xp.matmul(x1, x2, **kwargs) # Unlike transpose, matrix_transpose only transposes the last two axes. def matrix_transpose(x: ndarray, /, xp) -> ndarray: if x.ndim < 2: raise ValueError("x must be at least 2-dimensional for matrix_transpose") return xp.swapaxes(x, -1, -2) def tensordot(x1: ndarray, x2: ndarray, /, xp, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, **kwargs, ) -> ndarray: return xp.tensordot(x1, x2, axes=axes, **kwargs) def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray: if x1.shape[axis] != x2.shape[axis]: raise ValueError("x1 and x2 must have the same size along the given axis") if hasattr(xp, 'broadcast_tensors'): _broadcast = xp.broadcast_tensors else: _broadcast = xp.broadcast_arrays x1_ = xp.moveaxis(x1, axis, -1) x2_ = xp.moveaxis(x2, axis, -1) x1_, x2_ = _broadcast(x1_, x2_) res = xp.conj(x1_[..., None, :]) @ x2_[..., None] return res[..., 0, 0] # isdtype is a new function in the 2022.12 array API specification. def isdtype( dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], xp, *, _tuple=True, # Disallow nested tuples ) -> bool: """ Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``. Note that outside of this function, this compat library does not yet fully support complex numbers. 
See
    https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
    for more details
    """
    if isinstance(kind, tuple) and _tuple:
        return any(isdtype(dtype, k, xp, _tuple=False) for k in kind)
    elif isinstance(kind, str):
        if kind == 'bool':
            return dtype == xp.bool_
        elif kind == 'signed integer':
            return xp.issubdtype(dtype, xp.signedinteger)
        elif kind == 'unsigned integer':
            return xp.issubdtype(dtype, xp.unsignedinteger)
        elif kind == 'integral':
            return xp.issubdtype(dtype, xp.integer)
        elif kind == 'real floating':
            return xp.issubdtype(dtype, xp.floating)
        elif kind == 'complex floating':
            return xp.issubdtype(dtype, xp.complexfloating)
        elif kind == 'numeric':
            return xp.issubdtype(dtype, xp.number)
        else:
            raise ValueError(f"Unrecognized data type kind: {kind!r}")
    else:
        # This will allow things that aren't required by the spec, like
        # isdtype(np.float64, float) or isdtype(np.int64, 'l'). Should we be
        # more strict here to match the type annotation? Note that the
        # array_api_strict implementation will be very strict.
        return dtype == kind

# unstack is a new function in the 2023.12 array API standard
def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]:
    if x.ndim == 0:
        raise ValueError("Input array must be at least 1-d.")
    return tuple(xp.moveaxis(x, axis, 0))

# numpy 1.26 does not use the standard definition for sign on complex numbers
def sign(x: ndarray, /, xp, **kwargs) -> ndarray:
    if isdtype(x.dtype, 'complex floating', xp=xp):
        out = (x/xp.abs(x, **kwargs))[...]
        # sign(0) = 0 but the above formula would give nan
        out[x == 0+0j] = 0+0j
    else:
        out = xp.sign(x, **kwargs)
        # CuPy sign() does not propagate nans. See
        # https://github.com/data-apis/array-api-compat/issues/136
        if is_cupy_namespace(xp) and isdtype(x.dtype, 'real floating', xp=xp):
            out[xp.isnan(x)] = xp.nan
    return out[()]

__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
           'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
           'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
           'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
           'std', 'var', 'cumulative_sum', 'cumulative_prod', 'clip',
           'permute_dims', 'reshape', 'argsort', 'sort', 'nonzero', 'ceil',
           'floor', 'trunc', 'matmul', 'matrix_transpose', 'tensordot',
           'vecdot', 'isdtype', 'unstack', 'sign']

===== array-api-compat-1.11.2/array_api_compat/common/_fft.py =====

from __future__ import annotations

from typing import TYPE_CHECKING, Union, Optional, Literal

if TYPE_CHECKING:
    from ._typing import Device, ndarray, DType
    from collections.abc import Sequence

# Note: NumPy fft functions improperly upcast float32 and complex64 to
# complex128, which is why we require wrapping them all here.
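# Illustrative sketch of the upcast being worked around (assuming NumPy as the
# backing namespace):
#
#     import numpy as np
#     np.fft.fft(np.ones(4, dtype=np.complex64)).dtype   # complex128, upcast by NumPy
#     fft(np.ones(4, dtype=np.complex64), xp=np).dtype   # complex64, per the standard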
def fft( x: ndarray, /, xp, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> ndarray: res = xp.fft.fft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res def ifft( x: ndarray, /, xp, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> ndarray: res = xp.fft.ifft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res def fftn( x: ndarray, /, xp, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> ndarray: res = xp.fft.fftn(x, s=s, axes=axes, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res def ifftn( x: ndarray, /, xp, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> ndarray: res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res def rfft( x: ndarray, /, xp, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> ndarray: res = xp.fft.rfft(x, n=n, axis=axis, norm=norm) if x.dtype == xp.float32: return res.astype(xp.complex64) return res def irfft( x: ndarray, /, xp, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> ndarray: res = xp.fft.irfft(x, n=n, axis=axis, norm=norm) if x.dtype == xp.complex64: return res.astype(xp.float32) return res def rfftn( x: ndarray, /, xp, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> ndarray: res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm) if x.dtype == xp.float32: return res.astype(xp.complex64) return res def irfftn( x: ndarray, /, xp, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> ndarray: res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm) if x.dtype == xp.complex64: return res.astype(xp.float32) return res def hfft( x: ndarray, /, xp, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> ndarray: res = xp.fft.hfft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.float32) return res def ihfft( x: ndarray, /, xp, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> ndarray: res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res def fftfreq( n: int, /, xp, *, d: float = 1.0, dtype: Optional[DType] = None, device: Optional[Device] = None ) -> ndarray: if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") res = xp.fft.fftfreq(n, d=d) if dtype is not None: return res.astype(dtype) return res def rfftfreq( n: int, /, xp, *, d: float = 1.0, dtype: Optional[DType] = None, device: Optional[Device] = None ) -> ndarray: if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") res = xp.fft.rfftfreq(n, d=d) if dtype is not None: return res.astype(dtype) return res def fftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray: return xp.fft.fftshift(x, axes=axes) def ifftshift(x: ndarray, /, xp, *, axes: 
Union[int, Sequence[int]] = None) -> ndarray:
    return xp.fft.ifftshift(x, axes=axes)

__all__ = [
    "fft",
    "ifft",
    "fftn",
    "ifftn",
    "rfft",
    "irfft",
    "rfftn",
    "irfftn",
    "hfft",
    "ihfft",
    "fftfreq",
    "rfftfreq",
    "fftshift",
    "ifftshift",
]

===== array-api-compat-1.11.2/array_api_compat/common/_helpers.py =====

"""
Various helper functions which are not part of the spec.

Functions which start with an underscore are for internal use only but helpers
that are in __all__ are intended as additional helper functions for use by end
users of the compat library.
"""
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from typing import Optional, Union, Any
    from ._typing import Array, Device, Namespace

import sys
import math
import inspect
import warnings

def _is_jax_zero_gradient_array(x: object) -> bool:
    """Return True if `x` is a zero-gradient array.

    These arrays are a design quirk of Jax that may one day be removed.
    See https://github.com/google/jax/issues/20620.
    """
    if 'numpy' not in sys.modules or 'jax' not in sys.modules:
        return False

    import numpy as np
    import jax

    return isinstance(x, np.ndarray) and x.dtype == jax.float0

def is_numpy_array(x: object) -> bool:
    """
    Return True if `x` is a NumPy array.

    This function does not import NumPy if it has not already been imported
    and is therefore cheap to use.

    This also returns True for `ndarray` subclasses and NumPy scalar objects.

    See Also
    --------
    array_namespace
    is_array_api_obj
    is_cupy_array
    is_torch_array
    is_ndonnx_array
    is_dask_array
    is_jax_array
    is_pydata_sparse_array
    """
    # Avoid importing NumPy if it isn't already
    if 'numpy' not in sys.modules:
        return False

    import numpy as np

    # TODO: Should we reject ndarray subclasses?
    return (isinstance(x, (np.ndarray, np.generic))
            and not _is_jax_zero_gradient_array(x))

def is_cupy_array(x: object) -> bool:
    """
    Return True if `x` is a CuPy array.

    This function does not import CuPy if it has not already been imported
    and is therefore cheap to use.

    This also returns True for `cupy.ndarray` subclasses and CuPy scalar objects.

    See Also
    --------
    array_namespace
    is_array_api_obj
    is_numpy_array
    is_torch_array
    is_ndonnx_array
    is_dask_array
    is_jax_array
    is_pydata_sparse_array
    """
    # Avoid importing CuPy if it isn't already
    if 'cupy' not in sys.modules:
        return False

    import cupy as cp

    # TODO: Should we reject ndarray subclasses?
    return isinstance(x, cp.ndarray)

def is_torch_array(x: object) -> bool:
    """
    Return True if `x` is a PyTorch tensor.

    This function does not import PyTorch if it has not already been imported
    and is therefore cheap to use.

    See Also
    --------
    array_namespace
    is_array_api_obj
    is_numpy_array
    is_cupy_array
    is_dask_array
    is_jax_array
    is_pydata_sparse_array
    """
    # Avoid importing torch if it isn't already
    if 'torch' not in sys.modules:
        return False

    import torch

    # TODO: Should we reject ndarray subclasses?
    return isinstance(x, torch.Tensor)

def is_ndonnx_array(x: object) -> bool:
    """
    Return True if `x` is a ndonnx Array.

    This function does not import ndonnx if it has not already been imported
    and is therefore cheap to use.

    See Also
    --------
    array_namespace
    is_array_api_obj
    is_numpy_array
    is_cupy_array
    is_torch_array
    is_dask_array
    is_jax_array
    is_pydata_sparse_array
    """
    # Avoid importing ndonnx if it isn't already
    if 'ndonnx' not in sys.modules:
        return False

    import ndonnx as ndx

    return isinstance(x, ndx.Array)

def is_dask_array(x: object) -> bool:
    """
    Return True if `x` is a dask.array Array.
This function does not import dask if it has not already been imported and is therefore cheap to use. See Also -------- array_namespace is_array_api_obj is_numpy_array is_cupy_array is_torch_array is_ndonnx_array is_jax_array is_pydata_sparse_array """ # Avoid importing dask if it isn't already if 'dask.array' not in sys.modules: return False import dask.array return isinstance(x, dask.array.Array) def is_jax_array(x: object) -> bool: """ Return True if `x` is a JAX array. This function does not import JAX if it has not already been imported and is therefore cheap to use. See Also -------- array_namespace is_array_api_obj is_numpy_array is_cupy_array is_torch_array is_ndonnx_array is_dask_array is_pydata_sparse_array """ # Avoid importing jax if it isn't already if 'jax' not in sys.modules: return False import jax return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x) def is_pydata_sparse_array(x) -> bool: """ Return True if `x` is an array from the `sparse` package. This function does not import `sparse` if it has not already been imported and is therefore cheap to use. See Also -------- array_namespace is_array_api_obj is_numpy_array is_cupy_array is_torch_array is_ndonnx_array is_dask_array is_jax_array """ # Avoid importing jax if it isn't already if 'sparse' not in sys.modules: return False import sparse # TODO: Account for other backends. return isinstance(x, sparse.SparseArray) def is_array_api_obj(x: object) -> bool: """ Return True if `x` is an array API compatible array object. See Also -------- array_namespace is_numpy_array is_cupy_array is_torch_array is_ndonnx_array is_dask_array is_jax_array """ return is_numpy_array(x) \ or is_cupy_array(x) \ or is_torch_array(x) \ or is_dask_array(x) \ or is_jax_array(x) \ or is_pydata_sparse_array(x) \ or hasattr(x, '__array_namespace__') def _compat_module_name() -> str: assert __name__.endswith('.common._helpers') return __name__.removesuffix('.common._helpers') def is_numpy_namespace(xp) -> bool: """ Returns True if `xp` is a NumPy namespace. This includes both NumPy itself and the version wrapped by array-api-compat. See Also -------- array_namespace is_cupy_namespace is_torch_namespace is_ndonnx_namespace is_dask_namespace is_jax_namespace is_pydata_sparse_namespace is_array_api_strict_namespace """ return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'} def is_cupy_namespace(xp) -> bool: """ Returns True if `xp` is a CuPy namespace. This includes both CuPy itself and the version wrapped by array-api-compat. See Also -------- array_namespace is_numpy_namespace is_torch_namespace is_ndonnx_namespace is_dask_namespace is_jax_namespace is_pydata_sparse_namespace is_array_api_strict_namespace """ return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'} def is_torch_namespace(xp) -> bool: """ Returns True if `xp` is a PyTorch namespace. This includes both PyTorch itself and the version wrapped by array-api-compat. See Also -------- array_namespace is_numpy_namespace is_cupy_namespace is_ndonnx_namespace is_dask_namespace is_jax_namespace is_pydata_sparse_namespace is_array_api_strict_namespace """ return xp.__name__ in {'torch', _compat_module_name() + '.torch'} def is_ndonnx_namespace(xp) -> bool: """ Returns True if `xp` is an NDONNX namespace. 
See Also -------- array_namespace is_numpy_namespace is_cupy_namespace is_torch_namespace is_dask_namespace is_jax_namespace is_pydata_sparse_namespace is_array_api_strict_namespace """ return xp.__name__ == 'ndonnx' def is_dask_namespace(xp) -> bool: """ Returns True if `xp` is a Dask namespace. This includes both ``dask.array`` itself and the version wrapped by array-api-compat. See Also -------- array_namespace is_numpy_namespace is_cupy_namespace is_torch_namespace is_ndonnx_namespace is_jax_namespace is_pydata_sparse_namespace is_array_api_strict_namespace """ return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'} def is_jax_namespace(xp) -> bool: """ Returns True if `xp` is a JAX namespace. This includes ``jax.numpy`` and ``jax.experimental.array_api`` which existed in older versions of JAX. See Also -------- array_namespace is_numpy_namespace is_cupy_namespace is_torch_namespace is_ndonnx_namespace is_dask_namespace is_pydata_sparse_namespace is_array_api_strict_namespace """ return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'} def is_pydata_sparse_namespace(xp) -> bool: """ Returns True if `xp` is a pydata/sparse namespace. See Also -------- array_namespace is_numpy_namespace is_cupy_namespace is_torch_namespace is_ndonnx_namespace is_dask_namespace is_jax_namespace is_array_api_strict_namespace """ return xp.__name__ == 'sparse' def is_array_api_strict_namespace(xp) -> bool: """ Returns True if `xp` is an array-api-strict namespace. See Also -------- array_namespace is_numpy_namespace is_cupy_namespace is_torch_namespace is_ndonnx_namespace is_dask_namespace is_jax_namespace is_pydata_sparse_namespace """ return xp.__name__ == 'array_api_strict' def _check_api_version(api_version: str) -> None: if api_version in ['2021.12', '2022.12', '2023.12']: warnings.warn(f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2024.12") elif api_version is not None and api_version not in ['2021.12', '2022.12', '2023.12', '2024.12']: raise ValueError("Only the 2024.12 version of the array API specification is currently supported") def array_namespace(*xs, api_version=None, use_compat=None) -> Namespace: """ Get the array API compatible namespace for the arrays `xs`. Parameters ---------- xs: arrays one or more arrays. xs can also be Python scalars (bool, int, float, complex, or None), which are ignored. api_version: str The newest version of the spec that you need support for (currently the compat library wrapped APIs support v2024.12). use_compat: bool or None If None (the default), the native namespace will be returned if it is already array API compatible, otherwise a compat wrapper is used. If True, the compat library wrapped library will be returned. If False, the native library namespace is returned. Returns ------- out: namespace The array API compatible namespace corresponding to the arrays in `xs`. Raises ------ TypeError If `xs` contains arrays from different array libraries or contains a non-array. Typical usage is to pass the arguments of a function to `array_namespace()` at the top of a function to get the corresponding array API namespace: .. code:: python def your_function(x, y): xp = array_api_compat.array_namespace(x, y) # Now use xp as the array library namespace return xp.mean(x, axis=0) + 2*xp.std(y, axis=0) Wrapped array namespaces can also be imported directly. For example, `array_namespace(np.array(...))` will return `array_api_compat.numpy`. 
This function will also work for any array library not wrapped by array-api-compat if it explicitly defines `__array_namespace__ `__ (the wrapped namespace is always preferred if it exists). See Also -------- is_array_api_obj is_numpy_array is_cupy_array is_torch_array is_dask_array is_jax_array is_pydata_sparse_array """ if use_compat not in [None, True, False]: raise ValueError("use_compat must be None, True, or False") _use_compat = use_compat in [None, True] namespaces = set() for x in xs: if is_numpy_array(x): from .. import numpy as numpy_namespace import numpy as np if use_compat is True: _check_api_version(api_version) namespaces.add(numpy_namespace) elif use_compat is False: namespaces.add(np) else: # numpy 2.0+ have __array_namespace__, however, they are not yet fully array API # compatible. namespaces.add(numpy_namespace) elif is_cupy_array(x): if _use_compat: _check_api_version(api_version) from .. import cupy as cupy_namespace namespaces.add(cupy_namespace) else: import cupy as cp namespaces.add(cp) elif is_torch_array(x): if _use_compat: _check_api_version(api_version) from .. import torch as torch_namespace namespaces.add(torch_namespace) else: import torch namespaces.add(torch) elif is_dask_array(x): if _use_compat: _check_api_version(api_version) from ..dask import array as dask_namespace namespaces.add(dask_namespace) else: import dask.array as da namespaces.add(da) elif is_jax_array(x): if use_compat is True: _check_api_version(api_version) raise ValueError("JAX does not have an array-api-compat wrapper") elif use_compat is False: import jax.numpy as jnp else: # JAX v0.4.32 and newer implements the array API directly in jax.numpy. # For older JAX versions, it is available via jax.experimental.array_api. import jax.numpy if hasattr(jax.numpy, "__array_api_version__"): jnp = jax.numpy else: import jax.experimental.array_api as jnp namespaces.add(jnp) elif is_pydata_sparse_array(x): if use_compat is True: _check_api_version(api_version) raise ValueError("`sparse` does not have an array-api-compat wrapper") else: import sparse # `sparse` is already an array namespace. We do not have a wrapper # submodule for it. namespaces.add(sparse) elif hasattr(x, '__array_namespace__'): if use_compat is True: raise ValueError("The given array does not have an array-api-compat wrapper") namespaces.add(x.__array_namespace__(api_version=api_version)) elif isinstance(x, (bool, int, float, complex, type(None))): continue else: # TODO: Support Python scalars? raise TypeError(f"{type(x).__name__} is not a supported array type") if not namespaces: raise TypeError("Unrecognized array input") if len(namespaces) != 1: raise TypeError(f"Multiple namespaces for array inputs: {namespaces}") xp, = namespaces return xp # backwards compatibility alias get_namespace = array_namespace def _check_device(xp, device): if xp == sys.modules.get('numpy'): if device not in ["cpu", None]: raise ValueError(f"Unsupported device for NumPy: {device!r}") # Placeholder object to represent the dask device # when the array backend is not the CPU. # (since it is not easy to tell which device a dask array is on) class _dask_device: def __repr__(self): return "DASK_DEVICE" _DASK_DEVICE = _dask_device() # device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray # or cupy.ndarray. They are not included in array objects of this library # because this library just reuses the respective ndarray classes without # wrapping or subclassing them. 
These helper functions can be used instead of # the wrapper functions for libraries that need to support both NumPy/CuPy and # other libraries that use devices. def device(x: Array, /) -> Device: """ Hardware device the array data resides on. This is equivalent to `x.device` according to the `standard `__. This helper is included because some array libraries either do not have the `device` attribute or include it with an incompatible API. Parameters ---------- x: array array instance from an array API compatible library. Returns ------- out: device a ``device`` object (see the `Device Support `__ section of the array API specification). Notes ----- For NumPy the device is always `"cpu"`. For Dask, the device is always a special `DASK_DEVICE` object. See Also -------- to_device : Move array data to a different device. """ if is_numpy_array(x): return "cpu" elif is_dask_array(x): # Peek at the metadata of the Dask array to determine type if is_numpy_array(x._meta): # Must be on CPU since backed by numpy return "cpu" return _DASK_DEVICE elif is_jax_array(x): # FIXME Jitted JAX arrays do not have a device attribute # https://github.com/jax-ml/jax/issues/26000 # Return None in this case. Note that this workaround breaks # the standard and will result in new arrays being created on the # default device instead of the same device as the input array(s). x_device = getattr(x, 'device', None) # Older JAX releases had .device() as a method, which has been replaced # with a property in accordance with the standard. if inspect.ismethod(x_device): return x_device() else: return x_device elif is_pydata_sparse_array(x): # `sparse` will gain `.device`, so check for this first. x_device = getattr(x, 'device', None) if x_device is not None: return x_device # Everything but DOK has this attr. try: inner = x.data except AttributeError: return "cpu" # Return the device of the constituent array return device(inner) return x.device # Prevent shadowing, used below _device = device # Based on cupy.array_api.Array.to_device def _cupy_to_device(x, device, /, stream=None): import cupy as cp from cupy.cuda import Device as _Device from cupy.cuda import stream as stream_module from cupy_backends.cuda.api import runtime if device == x.device: return x elif device == "cpu": # allowing us to use `to_device(x, "cpu")` # is useful for portable test swapping between # host and device backends return x.get() elif not isinstance(device, _Device): raise ValueError(f"Unsupported device {device!r}") else: # see cupy/cupy#5985 for the reason how we handle device/stream here prev_device = runtime.getDevice() prev_stream: stream_module.Stream = None if stream is not None: prev_stream = stream_module.get_current_stream() # stream can be an int as specified in __dlpack__, or a CuPy stream if isinstance(stream, int): stream = cp.cuda.ExternalStream(stream) elif isinstance(stream, cp.cuda.Stream): pass else: raise ValueError('the input stream is not recognized') stream.use() try: runtime.setDevice(device.id) arr = x.copy() finally: runtime.setDevice(prev_device) if stream is not None: prev_stream.use() return arr def _torch_to_device(x, device, /, stream=None): if stream is not None: raise NotImplementedError return x.to(device) def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array: """ Copy the array from the device on which it currently resides to the specified ``device``. This is equivalent to `x.to_device(device, stream=stream)` according to the `standard `__. 
This helper is included because some array libraries do not have the `to_device` method. Parameters ---------- x: array array instance from an array API compatible library. device: device a ``device`` object (see the `Device Support `__ section of the array API specification). stream: Optional[Union[int, Any]] stream object to use during copy. In addition to the types supported in ``array.__dlpack__``, implementations may choose to support any library-specific stream object with the caveat that any code using such an object would not be portable. Returns ------- out: array an array with the same data and data type as ``x`` and located on the specified ``device``. Notes ----- For NumPy, this function effectively does nothing since the only supported device is the CPU. For CuPy, this method supports CuPy CUDA :external+cupy:class:`Device ` and :external+cupy:class:`Stream ` objects. For PyTorch, this is the same as :external+torch:meth:`x.to(device) ` (the ``stream`` argument is not supported in PyTorch). See Also -------- device : Hardware device the array data resides on. """ if is_numpy_array(x): if stream is not None: raise ValueError("The stream argument to to_device() is not supported") if device == 'cpu': return x raise ValueError(f"Unsupported device {device!r}") elif is_cupy_array(x): # cupy does not yet have to_device return _cupy_to_device(x, device, stream=stream) elif is_torch_array(x): return _torch_to_device(x, device, stream=stream) elif is_dask_array(x): if stream is not None: raise ValueError("The stream argument to to_device() is not supported") # TODO: What if our array is on the GPU already? if device == 'cpu': return x raise ValueError(f"Unsupported device {device!r}") elif is_jax_array(x): if not hasattr(x, "__array_namespace__"): # In JAX v0.4.31 and older, this import adds to_device method to x... import jax.experimental.array_api # noqa: F401 # ... but only on eager JAX. It won't work inside jax.jit. if not hasattr(x, "to_device"): return x return x.to_device(device, stream=stream) elif is_pydata_sparse_array(x) and device == _device(x): # Perform trivial check to return the same array if # device is same instead of err-ing. return x return x.to_device(device, stream=stream) def size(x: Array) -> int | None: """ Return the total number of elements of x. This is equivalent to `x.size` according to the `standard `__. This helper is included because PyTorch defines `size` in an :external+torch:meth:`incompatible way `. It also fixes dask.array's behaviour which returns nan for unknown sizes, whereas the standard requires None. """ # Lazy API compliant arrays, such as ndonnx, can contain None in their shape if None in x.shape: return None out = math.prod(x.shape) # dask.array.Array.shape can contain NaN return None if math.isnan(out) else out def is_writeable_array(x: object) -> bool: """ Return False if ``x.__setitem__`` is expected to raise; True otherwise. Return False if `x` is not an array API compatible object. Warning ------- As there is no standard way to check if an array is writeable without actually writing to it, this function blindly returns True for all unknown array types. """ if is_numpy_array(x): return x.flags.writeable if is_jax_array(x) or is_pydata_sparse_array(x): return False return is_array_api_obj(x) def is_lazy_array(x: object) -> bool: """Return True if x is potentially a future or it may be otherwise impossible or expensive to eagerly read its contents, regardless of their size, e.g. by calling ``bool(x)`` or ``float(x)``. 
Return False otherwise; e.g. ``bool(x)`` etc. is guaranteed to succeed and to be cheap as long as the array has the right dtype and size. Note ---- This function errs on the side of caution for array types that may or may not be lazy, e.g. JAX arrays, by always returning True for them. """ if ( is_numpy_array(x) or is_cupy_array(x) or is_torch_array(x) or is_pydata_sparse_array(x) ): return False # **JAX note:** while it is possible to determine if you're inside or outside # jax.jit by testing the subclass of a jax.Array object, as well as testing bool() # as we do below for unknown arrays, this is not recommended by JAX best practices. # **Dask note:** Dask eagerly computes the graph on __bool__, __float__, and so on. # This behaviour, while impossible to change without breaking backwards # compatibility, is highly detrimental to performance as the whole graph will end # up being computed multiple times. if is_jax_array(x) or is_dask_array(x) or is_ndonnx_array(x): return True if not is_array_api_obj(x): return False # Unknown Array API compatible object. Note that this test may have dire consequences # in terms of performance, e.g. for a lazy object that eagerly computes the graph # on __bool__ (dask is one such example, which however is special-cased above). # Select a single point of the array s = size(x) if s is None: return True xp = array_namespace(x) if s > 1: x = xp.reshape(x, (-1,))[0] # Cast to dtype=bool and deal with size 0 arrays x = xp.any(x) try: bool(x) return False # The Array API standard dictates that __bool__ should raise TypeError if the # output cannot be defined. # Here we allow for it to raise arbitrary exceptions, e.g. like Dask does. except Exception: return True __all__ = [ "array_namespace", "device", "get_namespace", "is_array_api_obj", "is_array_api_strict_namespace", "is_cupy_array", "is_cupy_namespace", "is_dask_array", "is_dask_namespace", "is_jax_array", "is_jax_namespace", "is_numpy_array", "is_numpy_namespace", "is_torch_array", "is_torch_namespace", "is_ndonnx_array", "is_ndonnx_namespace", "is_pydata_sparse_array", "is_pydata_sparse_namespace", "is_writeable_array", "is_lazy_array", "size", "to_device", ] _all_ignore = ['sys', 'math', 'inspect', 'warnings'] array-api-compat-1.11.2/array_api_compat/common/_linalg.py000066400000000000000000000137761476700770300235520ustar00rootroot00000000000000from __future__ import annotations from typing import TYPE_CHECKING, NamedTuple if TYPE_CHECKING: from typing import Literal, Optional, Tuple, Union from ._typing import ndarray import math import numpy as np if np.__version__[0] == "2": from numpy.lib.array_utils import normalize_axis_tuple else: from numpy.core.numeric import normalize_axis_tuple from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype from .._internal import get_xp # These are in the main NumPy namespace but not in numpy.linalg def cross(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1, **kwargs) -> ndarray: return xp.cross(x1, x2, axis=axis, **kwargs) def outer(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray: return xp.outer(x1, x2, **kwargs) class EighResult(NamedTuple): eigenvalues: ndarray eigenvectors: ndarray class QRResult(NamedTuple): Q: ndarray R: ndarray class SlogdetResult(NamedTuple): sign: ndarray logabsdet: ndarray class SVDResult(NamedTuple): U: ndarray S: ndarray Vh: ndarray # These functions are the same as their NumPy counterparts except they return # a namedtuple.
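# An illustrative sketch (not part of this module) of the namedtuple
# convenience these wrappers provide, using the numpy compat namespace and
# assuming numpy is installed:
#
#     import numpy as np
#     import array_api_compat.numpy as xp
#
#     res = xp.linalg.qr(np.eye(3))   # QRResult namedtuple
#     q, r = res.Q, res.R             # fields named as in the array API spec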
def eigh(x: ndarray, /, xp, **kwargs) -> EighResult: return EighResult(*xp.linalg.eigh(x, **kwargs)) def qr(x: ndarray, /, xp, *, mode: Literal['reduced', 'complete'] = 'reduced', **kwargs) -> QRResult: return QRResult(*xp.linalg.qr(x, mode=mode, **kwargs)) def slogdet(x: ndarray, /, xp, **kwargs) -> SlogdetResult: return SlogdetResult(*xp.linalg.slogdet(x, **kwargs)) def svd(x: ndarray, /, xp, *, full_matrices: bool = True, **kwargs) -> SVDResult: return SVDResult(*xp.linalg.svd(x, full_matrices=full_matrices, **kwargs)) # These functions have additional keyword arguments # The upper keyword argument is new from NumPy def cholesky(x: ndarray, /, xp, *, upper: bool = False, **kwargs) -> ndarray: L = xp.linalg.cholesky(x, **kwargs) if upper: U = get_xp(xp)(matrix_transpose)(L) if get_xp(xp)(isdtype)(U.dtype, 'complex floating'): U = xp.conj(U) return U return L # The rtol keyword argument of matrix_rank() and pinv() is new from NumPy. # Note that it has a different semantic meaning from tol and rcond. def matrix_rank(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None, **kwargs) -> ndarray: # this is different from xp.linalg.matrix_rank, which supports 1 # dimensional arrays. if x.ndim < 2: raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional") S = get_xp(xp)(svdvals)(x, **kwargs) if rtol is None: tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps else: # this is different from xp.linalg.matrix_rank, which does not # multiply the tolerance by the largest singular value. tol = S.max(axis=-1, keepdims=True)*xp.asarray(rtol)[..., xp.newaxis] return xp.count_nonzero(S > tol, axis=-1) def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None, **kwargs) -> ndarray: # this is different from xp.linalg.pinv, which does not multiply the # default tolerance by max(M, N). if rtol is None: rtol = max(x.shape[-2:]) * xp.finfo(x.dtype).eps return xp.linalg.pinv(x, rcond=rtol, **kwargs) # These functions are new in the array API spec def matrix_norm(x: ndarray, /, xp, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> ndarray: return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord) # svdvals is not in NumPy (but it is in SciPy). It is equivalent to # xp.linalg.svd(compute_uv=False). def svdvals(x: ndarray, /, xp) -> Union[ndarray, Tuple[ndarray, ...]]: return xp.linalg.svd(x, compute_uv=False) def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> ndarray: # xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or # when axis=None and the input is 2-D, so to force a vector norm, we make # it so the input is 1-D (for axis=None), or reshape so that norm is done # on a single dimension. if axis is None: # Note: xp.linalg.norm() doesn't handle 0-D arrays _x = x.ravel() _axis = 0 elif isinstance(axis, tuple): # Note: The axis argument supports any number of axes, whereas # xp.linalg.norm() only supports a single axis for vector norm. 
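        # Worked example of the reshape trick below (illustrative): for
        # x.shape == (2, 3, 4) and axis == (0, 2), x is transposed to shape
        # (2, 4, 3) and reshaped to (8, 3), so a single vector norm along
        # _axis = 0 produces the (3,)-shaped result the standard expects.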
normalized_axis = normalize_axis_tuple(axis, x.ndim) rest = tuple(i for i in range(x.ndim) if i not in normalized_axis) newshape = axis + rest _x = xp.transpose(x, newshape).reshape( (math.prod([x.shape[i] for i in axis]), *[x.shape[i] for i in rest])) _axis = 0 else: _x = x _axis = axis res = xp.linalg.norm(_x, axis=_axis, ord=ord) if keepdims: # We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks # above to avoid matrix norm logic. shape = list(x.shape) _axis = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim) for i in _axis: shape[i] = 1 res = xp.reshape(res, tuple(shape)) return res # xp.diagonal and xp.trace operate on the first two axes whereas these # operates on the last two def diagonal(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray: return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs) def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarray: return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)) __all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult', 'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet', 'svd', 'cholesky', 'matrix_rank', 'pinv', 'matrix_norm', 'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal', 'trace'] array-api-compat-1.11.2/array_api_compat/common/_typing.py000066400000000000000000000007361476700770300236100ustar00rootroot00000000000000from __future__ import annotations __all__ = [ "NestedSequence", "SupportsBufferProtocol", ] from types import ModuleType from typing import ( Any, TypeVar, Protocol, ) _T_co = TypeVar("_T_co", covariant=True) class NestedSequence(Protocol[_T_co]): def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... def __len__(self, /) -> int: ... SupportsBufferProtocol = Any Array = Any Device = Any DType = Any Namespace = ModuleType array-api-compat-1.11.2/array_api_compat/cupy/000077500000000000000000000000001476700770300212475ustar00rootroot00000000000000array-api-compat-1.11.2/array_api_compat/cupy/__init__.py000066400000000000000000000006721476700770300233650ustar00rootroot00000000000000from cupy import * # noqa: F403 # from cupy import * doesn't overwrite these builtin names from cupy import abs, max, min, round # noqa: F401 # These imports may overwrite names from the import * above. 
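# An illustrative consumption sketch (assumes CuPy and a CUDA device are
# available); downstream code imports this module as a spec-compliant
# namespace:
#
#     import array_api_compat.cupy as xp
#
#     x = xp.asarray([1.0, 2.0, 3.0])  # wrapped asarray with device/copy kwargs
#     y = xp.pow(x, 2)                 # renamed alias for cupy.power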
from ._aliases import * # noqa: F403 # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') __import__(__package__ + '.fft') from ..common._helpers import * # noqa: F401,F403 __array_api_version__ = '2024.12' array-api-compat-1.11.2/array_api_compat/cupy/_aliases.py000066400000000000000000000122061476700770300234020ustar00rootroot00000000000000from __future__ import annotations import cupy as cp from ..common import _aliases, _helpers from .._internal import get_xp from ._info import __array_namespace_info__ from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Optional, Union from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol bool = cp.bool_ # Basic renames acos = cp.arccos acosh = cp.arccosh asin = cp.arcsin asinh = cp.arcsinh atan = cp.arctan atan2 = cp.arctan2 atanh = cp.arctanh bitwise_left_shift = cp.left_shift bitwise_invert = cp.invert bitwise_right_shift = cp.right_shift concat = cp.concatenate pow = cp.power arange = get_xp(cp)(_aliases.arange) empty = get_xp(cp)(_aliases.empty) empty_like = get_xp(cp)(_aliases.empty_like) eye = get_xp(cp)(_aliases.eye) full = get_xp(cp)(_aliases.full) full_like = get_xp(cp)(_aliases.full_like) linspace = get_xp(cp)(_aliases.linspace) ones = get_xp(cp)(_aliases.ones) ones_like = get_xp(cp)(_aliases.ones_like) zeros = get_xp(cp)(_aliases.zeros) zeros_like = get_xp(cp)(_aliases.zeros_like) UniqueAllResult = get_xp(cp)(_aliases.UniqueAllResult) UniqueCountsResult = get_xp(cp)(_aliases.UniqueCountsResult) UniqueInverseResult = get_xp(cp)(_aliases.UniqueInverseResult) unique_all = get_xp(cp)(_aliases.unique_all) unique_counts = get_xp(cp)(_aliases.unique_counts) unique_inverse = get_xp(cp)(_aliases.unique_inverse) unique_values = get_xp(cp)(_aliases.unique_values) std = get_xp(cp)(_aliases.std) var = get_xp(cp)(_aliases.var) cumulative_sum = get_xp(cp)(_aliases.cumulative_sum) cumulative_prod = get_xp(cp)(_aliases.cumulative_prod) clip = get_xp(cp)(_aliases.clip) permute_dims = get_xp(cp)(_aliases.permute_dims) reshape = get_xp(cp)(_aliases.reshape) argsort = get_xp(cp)(_aliases.argsort) sort = get_xp(cp)(_aliases.sort) nonzero = get_xp(cp)(_aliases.nonzero) ceil = get_xp(cp)(_aliases.ceil) floor = get_xp(cp)(_aliases.floor) trunc = get_xp(cp)(_aliases.trunc) matmul = get_xp(cp)(_aliases.matmul) matrix_transpose = get_xp(cp)(_aliases.matrix_transpose) tensordot = get_xp(cp)(_aliases.tensordot) sign = get_xp(cp)(_aliases.sign) _copy_default = object() # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( obj: Union[ ndarray, bool, int, float, NestedSequence[bool | int | float], SupportsBufferProtocol, ], /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, copy: Optional[bool] = _copy_default, **kwargs, ) -> ndarray: """ Array API compatibility wrapper for asarray(). See the corresponding documentation in the array library and/or the array API specification for more details. """ with cp.cuda.Device(device): # cupy is like NumPy 1.26 (except without _CopyMode). See the comments # in asarray in numpy/_aliases.py. if copy is not _copy_default: # A future version of CuPy will change the meaning of copy=False # to mean no-copy. We don't know for certain what version it will # be yet, so to avoid breaking that version, we use a different # default value for copy so asarray(obj) with no copy kwarg will # always do the copy-if-needed behavior. 
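            # Summarising the mapping performed here (per the surrounding comments):
            #   copy=None  -> forwarded as copy=False (copy-if-needed in current CuPy)
            #   copy=True  -> forwarded as copy=True (always copy)
            #   copy=False -> NotImplementedError below
            #   no kwarg   -> 'copy' is not forwarded at all; cp.array's default applies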
# This will still need to be updated to remove the # NotImplementedError for copy=False, but at least this won't # break the default or existing behavior. if copy is None: copy = False elif copy is False: raise NotImplementedError("asarray(copy=False) is not yet supported in cupy") kwargs['copy'] = copy return cp.array(obj, dtype=dtype, **kwargs) def astype( x: ndarray, dtype: Dtype, /, *, copy: bool = True, device: Optional[Device] = None, ) -> ndarray: if device is None: return x.astype(dtype=dtype, copy=copy) out = _helpers.to_device(x.astype(dtype=dtype, copy=False), device) return out.copy() if copy and out is x else out # cupy.count_nonzero does not have keepdims def count_nonzero( x: ndarray, axis=None, keepdims=False ) -> ndarray: result = cp.count_nonzero(x, axis) if keepdims: if axis is None: return cp.reshape(result, [1]*x.ndim) return cp.expand_dims(result, axis) return result # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(cp, 'vecdot'): vecdot = cp.vecdot else: vecdot = get_xp(cp)(_aliases.vecdot) if hasattr(cp, 'isdtype'): isdtype = cp.isdtype else: isdtype = get_xp(cp)(_aliases.isdtype) if hasattr(cp, 'unstack'): unstack = cp.unstack else: unstack = get_xp(cp)(_aliases.unstack) __all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', 'bool', 'concat', 'count_nonzero', 'pow', 'sign'] _all_ignore = ['cp', 'get_xp'] array-api-compat-1.11.2/array_api_compat/cupy/_info.py000066400000000000000000000231211476700770300227120ustar00rootroot00000000000000""" Array API Inspection namespace This is the namespace for inspection functions as defined by the array API standard. See https://data-apis.org/array-api/latest/API_specification/inspection.html for more details. """ from cupy import ( dtype, cuda, bool_ as bool, intp, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, complex64, complex128, ) class __array_namespace_info__: """ Get the array API inspection namespace for CuPy. The array API inspection namespace defines the following functions: - capabilities() - default_device() - default_dtypes() - dtypes() - devices() See https://data-apis.org/array-api/latest/API_specification/inspection.html for more details. Returns ------- info : ModuleType The array API inspection namespace for CuPy. Examples -------- >>> info = np.__array_namespace_info__() >>> info.default_dtypes() {'real floating': cupy.float64, 'complex floating': cupy.complex128, 'integral': cupy.int64, 'indexing': cupy.int64} """ __module__ = 'cupy' def capabilities(self): """ Return a dictionary of array API library capabilities. The resulting dictionary has the following keys: - **"boolean indexing"**: boolean indicating whether an array library supports boolean indexing. Always ``True`` for CuPy. - **"data-dependent shapes"**: boolean indicating whether an array library supports data-dependent output shapes. Always ``True`` for CuPy. See https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html for more details. See Also -------- __array_namespace_info__.default_device, __array_namespace_info__.default_dtypes, __array_namespace_info__.dtypes, __array_namespace_info__.devices Returns ------- capabilities : dict A dictionary of array API library capabilities. 
Examples -------- >>> info = xp.__array_namespace_info__() >>> info.capabilities() {'boolean indexing': True, 'data-dependent shapes': True} """ return { "boolean indexing": True, "data-dependent shapes": True, # 'max rank' will be part of the 2024.12 standard "max dimensions": 64, } def default_device(self): """ The default device used for new CuPy arrays. See Also -------- __array_namespace_info__.capabilities, __array_namespace_info__.default_dtypes, __array_namespace_info__.dtypes, __array_namespace_info__.devices Returns ------- device : str The default device used for new CuPy arrays. Examples -------- >>> info = xp.__array_namespace_info__() >>> info.default_device() Device(0) """ return cuda.Device(0) def default_dtypes(self, *, device=None): """ The default data types used for new CuPy arrays. For CuPy, this always returns the following dictionary: - **"real floating"**: ``cupy.float64`` - **"complex floating"**: ``cupy.complex128`` - **"integral"**: ``cupy.intp`` - **"indexing"**: ``cupy.intp`` Parameters ---------- device : str, optional The device to get the default data types for. Returns ------- dtypes : dict A dictionary describing the default data types used for new CuPy arrays. See Also -------- __array_namespace_info__.capabilities, __array_namespace_info__.default_device, __array_namespace_info__.dtypes, __array_namespace_info__.devices Examples -------- >>> info = xp.__array_namespace_info__() >>> info.default_dtypes() {'real floating': cupy.float64, 'complex floating': cupy.complex128, 'integral': cupy.int64, 'indexing': cupy.int64} """ # TODO: Does this depend on device? return { "real floating": dtype(float64), "complex floating": dtype(complex128), "integral": dtype(intp), "indexing": dtype(intp), } def dtypes(self, *, device=None, kind=None): """ The array API data types supported by CuPy. Note that this function only returns data types that are defined by the array API. Parameters ---------- device : str, optional The device to get the data types for. kind : str or tuple of str, optional The kind of data types to return. If ``None``, all data types are returned. If a string, only data types of that kind are returned. If a tuple, a dictionary containing the union of the given kinds is returned. The following kinds are supported: - ``'bool'``: boolean data types (i.e., ``bool``). - ``'signed integer'``: signed integer data types (i.e., ``int8``, ``int16``, ``int32``, ``int64``). - ``'unsigned integer'``: unsigned integer data types (i.e., ``uint8``, ``uint16``, ``uint32``, ``uint64``). - ``'integral'``: integer data types. Shorthand for ``('signed integer', 'unsigned integer')``. - ``'real floating'``: real-valued floating-point data types (i.e., ``float32``, ``float64``). - ``'complex floating'``: complex floating-point data types (i.e., ``complex64``, ``complex128``). - ``'numeric'``: numeric data types. Shorthand for ``('integral', 'real floating', 'complex floating')``. Returns ------- dtypes : dict A dictionary mapping the names of data types to the corresponding CuPy data types. See Also -------- __array_namespace_info__.capabilities, __array_namespace_info__.default_device, __array_namespace_info__.default_dtypes, __array_namespace_info__.devices Examples -------- >>> info = xp.__array_namespace_info__() >>> info.dtypes(kind='signed integer') {'int8': cupy.int8, 'int16': cupy.int16, 'int32': cupy.int32, 'int64': cupy.int64} """ # TODO: Does this depend on device? 
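        # Usage example (illustrative): info.dtypes(kind=('bool', 'real floating'))
        # recurses through the tuple branch below and returns the union
        # {'bool': ..., 'float32': ..., 'float64': ...}.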
if kind is None: return { "bool": dtype(bool), "int8": dtype(int8), "int16": dtype(int16), "int32": dtype(int32), "int64": dtype(int64), "uint8": dtype(uint8), "uint16": dtype(uint16), "uint32": dtype(uint32), "uint64": dtype(uint64), "float32": dtype(float32), "float64": dtype(float64), "complex64": dtype(complex64), "complex128": dtype(complex128), } if kind == "bool": return {"bool": bool} if kind == "signed integer": return { "int8": dtype(int8), "int16": dtype(int16), "int32": dtype(int32), "int64": dtype(int64), } if kind == "unsigned integer": return { "uint8": dtype(uint8), "uint16": dtype(uint16), "uint32": dtype(uint32), "uint64": dtype(uint64), } if kind == "integral": return { "int8": dtype(int8), "int16": dtype(int16), "int32": dtype(int32), "int64": dtype(int64), "uint8": dtype(uint8), "uint16": dtype(uint16), "uint32": dtype(uint32), "uint64": dtype(uint64), } if kind == "real floating": return { "float32": dtype(float32), "float64": dtype(float64), } if kind == "complex floating": return { "complex64": dtype(complex64), "complex128": dtype(complex128), } if kind == "numeric": return { "int8": dtype(int8), "int16": dtype(int16), "int32": dtype(int32), "int64": dtype(int64), "uint8": dtype(uint8), "uint16": dtype(uint16), "uint32": dtype(uint32), "uint64": dtype(uint64), "float32": dtype(float32), "float64": dtype(float64), "complex64": dtype(complex64), "complex128": dtype(complex128), } if isinstance(kind, tuple): res = {} for k in kind: res.update(self.dtypes(kind=k)) return res raise ValueError(f"unsupported kind: {kind!r}") def devices(self): """ The devices supported by CuPy. Returns ------- devices : list of str The devices supported by CuPy. See Also -------- __array_namespace_info__.capabilities, __array_namespace_info__.default_device, __array_namespace_info__.default_dtypes, __array_namespace_info__.dtypes """ return [cuda.Device(i) for i in range(cuda.runtime.getDeviceCount())] array-api-compat-1.11.2/array_api_compat/cupy/_typing.py000066400000000000000000000011511476700770300232700ustar00rootroot00000000000000from __future__ import annotations __all__ = [ "ndarray", "Device", "Dtype", ] import sys from typing import ( Union, TYPE_CHECKING, ) from cupy import ( ndarray, dtype, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, ) from cupy.cuda.device import Device if TYPE_CHECKING or sys.version_info >= (3, 9): Dtype = dtype[Union[ int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, ]] else: Dtype = dtype array-api-compat-1.11.2/array_api_compat/cupy/fft.py000066400000000000000000000015121476700770300223770ustar00rootroot00000000000000from cupy.fft import * # noqa: F403 # cupy.fft doesn't have __all__. 
If it is added, replace this with # # from cupy.fft import __all__ as linalg_all _n = {} exec('from cupy.fft import *', _n) del _n['__builtins__'] fft_all = list(_n) del _n from ..common import _fft from .._internal import get_xp import cupy as cp fft = get_xp(cp)(_fft.fft) ifft = get_xp(cp)(_fft.ifft) fftn = get_xp(cp)(_fft.fftn) ifftn = get_xp(cp)(_fft.ifftn) rfft = get_xp(cp)(_fft.rfft) irfft = get_xp(cp)(_fft.irfft) rfftn = get_xp(cp)(_fft.rfftn) irfftn = get_xp(cp)(_fft.irfftn) hfft = get_xp(cp)(_fft.hfft) ihfft = get_xp(cp)(_fft.ihfft) fftfreq = get_xp(cp)(_fft.fftfreq) rfftfreq = get_xp(cp)(_fft.rfftfreq) fftshift = get_xp(cp)(_fft.fftshift) ifftshift = get_xp(cp)(_fft.ifftshift) __all__ = fft_all + _fft.__all__ del get_xp del cp del fft_all del _fft array-api-compat-1.11.2/array_api_compat/cupy/linalg.py000066400000000000000000000026441476700770300230750ustar00rootroot00000000000000from cupy.linalg import * # noqa: F403 # cupy.linalg doesn't have __all__. If it is added, replace this with # # from cupy.linalg import __all__ as linalg_all _n = {} exec('from cupy.linalg import *', _n) del _n['__builtins__'] linalg_all = list(_n) del _n from ..common import _linalg from .._internal import get_xp import cupy as cp # These functions are in both the main and linalg namespaces from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401 cross = get_xp(cp)(_linalg.cross) outer = get_xp(cp)(_linalg.outer) EighResult = _linalg.EighResult QRResult = _linalg.QRResult SlogdetResult = _linalg.SlogdetResult SVDResult = _linalg.SVDResult eigh = get_xp(cp)(_linalg.eigh) qr = get_xp(cp)(_linalg.qr) slogdet = get_xp(cp)(_linalg.slogdet) svd = get_xp(cp)(_linalg.svd) cholesky = get_xp(cp)(_linalg.cholesky) matrix_rank = get_xp(cp)(_linalg.matrix_rank) pinv = get_xp(cp)(_linalg.pinv) matrix_norm = get_xp(cp)(_linalg.matrix_norm) svdvals = get_xp(cp)(_linalg.svdvals) diagonal = get_xp(cp)(_linalg.diagonal) trace = get_xp(cp)(_linalg.trace) # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(cp.linalg, 'vector_norm'): vector_norm = cp.linalg.vector_norm else: vector_norm = get_xp(cp)(_linalg.vector_norm) __all__ = linalg_all + _linalg.__all__ del get_xp del cp del linalg_all del _linalg array-api-compat-1.11.2/array_api_compat/dask/000077500000000000000000000000001476700770300212115ustar00rootroot00000000000000array-api-compat-1.11.2/array_api_compat/dask/__init__.py000066400000000000000000000000001476700770300233100ustar00rootroot00000000000000array-api-compat-1.11.2/array_api_compat/dask/array/000077500000000000000000000000001476700770300223275ustar00rootroot00000000000000array-api-compat-1.11.2/array_api_compat/dask/array/__init__.py000066400000000000000000000003621476700770300244410ustar00rootroot00000000000000from dask.array import * # noqa: F403 # These imports may overwrite names from the import * above. 
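# An illustrative consumption sketch (assumes dask is installed), mirroring
# the cupy wrapper earlier in this package:
#
#     import array_api_compat.dask.array as xp
#
#     x = xp.asarray([3, 1, 2])
#     xp.sort(x).compute()   # sort() is provided by ._aliases, imported below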
from ._aliases import * # noqa: F403 __array_api_version__ = '2024.12' __import__(__package__ + '.linalg') __import__(__package__ + '.fft') array-api-compat-1.11.2/array_api_compat/dask/array/_aliases.py000066400000000000000000000237541476700770300244740ustar00rootroot00000000000000from __future__ import annotations from typing import Callable from ...common import _aliases, array_namespace from ..._internal import get_xp from ._info import __array_namespace_info__ import numpy as np from numpy import ( # Dtypes iinfo, finfo, bool_ as bool, float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64, complex64, complex128, can_cast, result_type, ) from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Optional, Union from ...common._typing import ( Device, Dtype, Array, NestedSequence, SupportsBufferProtocol, ) import dask.array as da isdtype = get_xp(np)(_aliases.isdtype) unstack = get_xp(da)(_aliases.unstack) # da.astype doesn't respect copy=True def astype( x: Array, dtype: Dtype, /, *, copy: bool = True, device: Optional[Device] = None, ) -> Array: """ Array API compatibility wrapper for astype(). See the corresponding documentation in the array library and/or the array API specification for more details. """ # TODO: respect device keyword? if not copy and dtype == x.dtype: return x x = x.astype(dtype) return x.copy() if copy else x # Common aliases # This arange func is modified from the common one to # not pass stop/step as keyword arguments, which will cause # an error with dask def arange( start: Union[int, float], /, stop: Optional[Union[int, float]] = None, step: Union[int, float] = 1, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, **kwargs, ) -> Array: """ Array API compatibility wrapper for arange(). See the corresponding documentation in the array library and/or the array API specification for more details. """ # TODO: respect device keyword? 
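    # Worked example of the positional-argument juggling below: arange(5)
    # arrives here as start=5, stop=None and is rewritten to da.arange(0, 5, 1),
    # while arange(2, 8) becomes da.arange(2, 8, 1).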
args = [start] if stop is not None: args.append(stop) else: # stop is None, so start is actually stop # prepend the default value for start which is 0 args.insert(0, 0) args.append(step) return da.arange(*args, dtype=dtype, **kwargs) eye = get_xp(da)(_aliases.eye) linspace = get_xp(da)(_aliases.linspace) UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult) UniqueCountsResult = get_xp(da)(_aliases.UniqueCountsResult) UniqueInverseResult = get_xp(da)(_aliases.UniqueInverseResult) unique_all = get_xp(da)(_aliases.unique_all) unique_counts = get_xp(da)(_aliases.unique_counts) unique_inverse = get_xp(da)(_aliases.unique_inverse) unique_values = get_xp(da)(_aliases.unique_values) permute_dims = get_xp(da)(_aliases.permute_dims) std = get_xp(da)(_aliases.std) var = get_xp(da)(_aliases.var) cumulative_sum = get_xp(da)(_aliases.cumulative_sum) cumulative_prod = get_xp(da)(_aliases.cumulative_prod) empty = get_xp(da)(_aliases.empty) empty_like = get_xp(da)(_aliases.empty_like) full = get_xp(da)(_aliases.full) full_like = get_xp(da)(_aliases.full_like) ones = get_xp(da)(_aliases.ones) ones_like = get_xp(da)(_aliases.ones_like) zeros = get_xp(da)(_aliases.zeros) zeros_like = get_xp(da)(_aliases.zeros_like) reshape = get_xp(da)(_aliases.reshape) matrix_transpose = get_xp(da)(_aliases.matrix_transpose) vecdot = get_xp(da)(_aliases.vecdot) nonzero = get_xp(da)(_aliases.nonzero) ceil = get_xp(np)(_aliases.ceil) floor = get_xp(np)(_aliases.floor) trunc = get_xp(np)(_aliases.trunc) matmul = get_xp(np)(_aliases.matmul) tensordot = get_xp(np)(_aliases.tensordot) sign = get_xp(np)(_aliases.sign) # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( obj: Union[ Array, bool, int, float, NestedSequence[bool | int | float], SupportsBufferProtocol, ], /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, copy: Optional[Union[bool, np._CopyMode]] = None, **kwargs, ) -> Array: """ Array API compatibility wrapper for asarray(). See the corresponding documentation in the array library and/or the array API specification for more details. """ # TODO: respect device keyword? if isinstance(obj, da.Array): if dtype is not None and dtype != obj.dtype: if copy is False: raise ValueError("Unable to avoid copy when changing dtype") obj = obj.astype(dtype) return obj.copy() if copy else obj if copy is False: raise NotImplementedError( "Unable to avoid copy when converting a non-dask object to dask" ) # copy=None to be uniform across dask < 2024.12 and >= 2024.12 # see https://github.com/dask/dask/pull/11524/ obj = np.array(obj, dtype=dtype, copy=True) return da.from_array(obj) from dask.array import ( # Element wise aliases arccos as acos, arccosh as acosh, arcsin as asin, arcsinh as asinh, arctan as atan, arctan2 as atan2, arctanh as atanh, left_shift as bitwise_left_shift, right_shift as bitwise_right_shift, invert as bitwise_invert, power as pow, # Other concatenate as concat, ) # dask.array.clip does not work unless all three arguments are provided. # Furthermore, the masking workaround in common._aliases.clip cannot work with # dask (meaning uint64 promoting to float64 is going to just be unfixed for # now). def clip( x: Array, /, min: Optional[Union[int, float, Array]] = None, max: Optional[Union[int, float, Array]] = None, ) -> Array: """ Array API compatibility wrapper for clip(). See the corresponding documentation in the array library and/or the array API specification for more details. 
""" def _isscalar(a): return isinstance(a, (int, float, type(None))) min_shape = () if _isscalar(min) else min.shape max_shape = () if _isscalar(max) else max.shape # TODO: This won't handle dask unknown shapes result_shape = np.broadcast_shapes(x.shape, min_shape, max_shape) if min is not None: min = da.broadcast_to(da.asarray(min), result_shape) if max is not None: max = da.broadcast_to(da.asarray(max), result_shape) if min is None and max is None: return da.positive(x) if min is None: return astype(da.minimum(x, max), x.dtype) if max is None: return astype(da.maximum(x, min), x.dtype) return astype(da.minimum(da.maximum(x, min), max), x.dtype) def _ensure_single_chunk(x: Array, axis: int) -> tuple[Array, Callable[[Array], Array]]: """ Make sure that Array is not broken into multiple chunks along axis. Returns ------- x : Array The input Array with a single chunk along axis. restore : Callable[Array, Array] function to apply to the output to rechunk it back into reasonable chunks """ if axis < 0: axis += x.ndim if x.numblocks[axis] < 2: return x, lambda x: x # Break chunks on other axes in an attempt to keep chunk size low x = x.rechunk({i: -1 if i == axis else "auto" for i in range(x.ndim)}) # Rather than reconstructing the original chunks, which can be a # very expensive affair, just break down oversized chunks without # incurring in any transfers over the network. # This has the downside of a risk of overchunking if the array is # then used in operations against other arrays that match the # original chunking pattern. return x, lambda x: x.rechunk() def sort( x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True ) -> Array: """ Array API compatibility layer around the lack of sort() in Dask. Warnings -------- This function temporarily rechunks the array along `axis` to a single chunk. This can be extremely inefficient and can lead to out-of-memory errors. See the corresponding documentation in the array library and/or the array API specification for more details. """ x, restore = _ensure_single_chunk(x, axis) meta_xp = array_namespace(x._meta) x = da.map_blocks( meta_xp.sort, x, axis=axis, meta=x._meta, dtype=x.dtype, descending=descending, stable=stable, ) return restore(x) def argsort( x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True ) -> Array: """ Array API compatibility layer around the lack of argsort() in Dask. See the corresponding documentation in the array library and/or the array API specification for more details. Warnings -------- This function temporarily rechunks the array along `axis` into a single chunk. This can be extremely inefficient and can lead to out-of-memory errors. 
""" x, restore = _ensure_single_chunk(x, axis) meta_xp = array_namespace(x._meta) dtype = meta_xp.argsort(x._meta).dtype meta = meta_xp.astype(x._meta, dtype) x = da.map_blocks( meta_xp.argsort, x, axis=axis, meta=meta, dtype=dtype, descending=descending, stable=stable, ) return restore(x) # dask.array.count_nonzero does not have keepdims def count_nonzero( x: Array, axis=None, keepdims=False ) -> Array: result = da.count_nonzero(x, axis) if keepdims: if axis is None: return da.reshape(result, [1]*x.ndim) return da.expand_dims(result, axis) return result __all__ = _aliases.__all__ + [ '__array_namespace_info__', 'asarray', 'astype', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', 'concat', 'pow', 'iinfo', 'finfo', 'can_cast', 'result_type', 'bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64', 'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'count_nonzero', 'result_type'] _all_ignore = ["Callable", "array_namespace", "get_xp", "da", "np"] array-api-compat-1.11.2/array_api_compat/dask/array/_info.py000066400000000000000000000242561476700770300240040ustar00rootroot00000000000000""" Array API Inspection namespace This is the namespace for inspection functions as defined by the array API standard. See https://data-apis.org/array-api/latest/API_specification/inspection.html for more details. """ from numpy import ( dtype, bool_ as bool, intp, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, complex64, complex128, ) from ...common._helpers import _DASK_DEVICE class __array_namespace_info__: """ Get the array API inspection namespace for Dask. The array API inspection namespace defines the following functions: - capabilities() - default_device() - default_dtypes() - dtypes() - devices() See https://data-apis.org/array-api/latest/API_specification/inspection.html for more details. Returns ------- info : ModuleType The array API inspection namespace for Dask. Examples -------- >>> info = np.__array_namespace_info__() >>> info.default_dtypes() {'real floating': dask.float64, 'complex floating': dask.complex128, 'integral': dask.int64, 'indexing': dask.int64} """ __module__ = 'dask.array' def capabilities(self): """ Return a dictionary of array API library capabilities. The resulting dictionary has the following keys: - **"boolean indexing"**: boolean indicating whether an array library supports boolean indexing. Always ``False`` for Dask. - **"data-dependent shapes"**: boolean indicating whether an array library supports data-dependent output shapes. Always ``False`` for Dask. See https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html for more details. See Also -------- __array_namespace_info__.default_device, __array_namespace_info__.default_dtypes, __array_namespace_info__.dtypes, __array_namespace_info__.devices Returns ------- capabilities : dict A dictionary of array API library capabilities. Examples -------- >>> info = np.__array_namespace_info__() >>> info.capabilities() {'boolean indexing': True, 'data-dependent shapes': True} """ return { "boolean indexing": False, "data-dependent shapes": False, # 'max rank' will be part of the 2024.12 standard "max dimensions": 64, } def default_device(self): """ The default device used for new Dask arrays. For Dask, this always returns ``'cpu'``. 
See Also -------- __array_namespace_info__.capabilities, __array_namespace_info__.default_dtypes, __array_namespace_info__.dtypes, __array_namespace_info__.devices Returns ------- device : str The default device used for new Dask arrays. Examples -------- >>> info = np.__array_namespace_info__() >>> info.default_device() 'cpu' """ return "cpu" def default_dtypes(self, *, device=None): """ The default data types used for new Dask arrays. For Dask, this always returns the following dictionary: - **"real floating"**: ``numpy.float64`` - **"complex floating"**: ``numpy.complex128`` - **"integral"**: ``numpy.intp`` - **"indexing"**: ``numpy.intp`` Parameters ---------- device : str, optional The device to get the default data types for. Returns ------- dtypes : dict A dictionary describing the default data types used for new Dask arrays. See Also -------- __array_namespace_info__.capabilities, __array_namespace_info__.default_device, __array_namespace_info__.dtypes, __array_namespace_info__.devices Examples -------- >>> info = np.__array_namespace_info__() >>> info.default_dtypes() {'real floating': dask.float64, 'complex floating': dask.complex128, 'integral': dask.int64, 'indexing': dask.int64} """ if device not in ["cpu", _DASK_DEVICE, None]: raise ValueError( 'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, but received:' f' {device}' ) return { "real floating": dtype(float64), "complex floating": dtype(complex128), "integral": dtype(intp), "indexing": dtype(intp), } def dtypes(self, *, device=None, kind=None): """ The array API data types supported by Dask. Note that this function only returns data types that are defined by the array API. Parameters ---------- device : str, optional The device to get the data types for. kind : str or tuple of str, optional The kind of data types to return. If ``None``, all data types are returned. If a string, only data types of that kind are returned. If a tuple, a dictionary containing the union of the given kinds is returned. The following kinds are supported: - ``'bool'``: boolean data types (i.e., ``bool``). - ``'signed integer'``: signed integer data types (i.e., ``int8``, ``int16``, ``int32``, ``int64``). - ``'unsigned integer'``: unsigned integer data types (i.e., ``uint8``, ``uint16``, ``uint32``, ``uint64``). - ``'integral'``: integer data types. Shorthand for ``('signed integer', 'unsigned integer')``. - ``'real floating'``: real-valued floating-point data types (i.e., ``float32``, ``float64``). - ``'complex floating'``: complex floating-point data types (i.e., ``complex64``, ``complex128``). - ``'numeric'``: numeric data types. Shorthand for ``('integral', 'real floating', 'complex floating')``. Returns ------- dtypes : dict A dictionary mapping the names of data types to the corresponding Dask data types. See Also -------- __array_namespace_info__.capabilities, __array_namespace_info__.default_device, __array_namespace_info__.default_dtypes, __array_namespace_info__.devices Examples -------- >>> info = np.__array_namespace_info__() >>> info.dtypes(kind='signed integer') {'int8': dask.int8, 'int16': dask.int16, 'int32': dask.int32, 'int64': dask.int64} """ if device not in ["cpu", _DASK_DEVICE, None]: raise ValueError( 'Device not understood. 
Only "cpu" or _DASK_DEVICE is allowed, but received:' f' {device}' ) if kind is None: return { "bool": dtype(bool), "int8": dtype(int8), "int16": dtype(int16), "int32": dtype(int32), "int64": dtype(int64), "uint8": dtype(uint8), "uint16": dtype(uint16), "uint32": dtype(uint32), "uint64": dtype(uint64), "float32": dtype(float32), "float64": dtype(float64), "complex64": dtype(complex64), "complex128": dtype(complex128), } if kind == "bool": return {"bool": bool} if kind == "signed integer": return { "int8": dtype(int8), "int16": dtype(int16), "int32": dtype(int32), "int64": dtype(int64), } if kind == "unsigned integer": return { "uint8": dtype(uint8), "uint16": dtype(uint16), "uint32": dtype(uint32), "uint64": dtype(uint64), } if kind == "integral": return { "int8": dtype(int8), "int16": dtype(int16), "int32": dtype(int32), "int64": dtype(int64), "uint8": dtype(uint8), "uint16": dtype(uint16), "uint32": dtype(uint32), "uint64": dtype(uint64), } if kind == "real floating": return { "float32": dtype(float32), "float64": dtype(float64), } if kind == "complex floating": return { "complex64": dtype(complex64), "complex128": dtype(complex128), } if kind == "numeric": return { "int8": dtype(int8), "int16": dtype(int16), "int32": dtype(int32), "int64": dtype(int64), "uint8": dtype(uint8), "uint16": dtype(uint16), "uint32": dtype(uint32), "uint64": dtype(uint64), "float32": dtype(float32), "float64": dtype(float64), "complex64": dtype(complex64), "complex128": dtype(complex128), } if isinstance(kind, tuple): res = {} for k in kind: res.update(self.dtypes(kind=k)) return res raise ValueError(f"unsupported kind: {kind!r}") def devices(self): """ The devices supported by Dask. For Dask, this always returns ``['cpu', DASK_DEVICE]``. Returns ------- devices : list of str The devices supported by Dask. See Also -------- __array_namespace_info__.capabilities, __array_namespace_info__.default_device, __array_namespace_info__.default_dtypes, __array_namespace_info__.dtypes Examples -------- >>> info = np.__array_namespace_info__() >>> info.devices() ['cpu', DASK_DEVICE] """ return ["cpu", _DASK_DEVICE] array-api-compat-1.11.2/array_api_compat/dask/array/fft.py000066400000000000000000000010511476700770300234550ustar00rootroot00000000000000from dask.array.fft import * # noqa: F403 # dask.array.fft doesn't have __all__. If it is added, replace this with # # from dask.array.fft import __all__ as linalg_all _n = {} exec('from dask.array.fft import *', _n) del _n['__builtins__'] fft_all = list(_n) del _n from ...common import _fft from ..._internal import get_xp import dask.array as da fftfreq = get_xp(da)(_fft.fftfreq) rfftfreq = get_xp(da)(_fft.rfftfreq) __all__ = [elem for elem in fft_all if elem != "annotations"] + ["fftfreq", "rfftfreq"] del get_xp del da del fft_all del _fft array-api-compat-1.11.2/array_api_compat/dask/array/linalg.py000066400000000000000000000046111476700770300241510ustar00rootroot00000000000000from __future__ import annotations from ...common import _linalg from ..._internal import get_xp # Exports from dask.array.linalg import * # noqa: F403 from dask.array import outer # These functions are in both the main and linalg namespaces from dask.array import matmul, tensordot from ._aliases import matrix_transpose, vecdot import dask.array as da from typing import TYPE_CHECKING if TYPE_CHECKING: from ...common._typing import Array from typing import Literal # dask.array.linalg doesn't have __all__. 
If it is added, replace this with # # from dask.array.linalg import __all__ as linalg_all _n = {} exec('from dask.array.linalg import *', _n) del _n['__builtins__'] if 'annotations' in _n: del _n['annotations'] linalg_all = list(_n) del _n EighResult = _linalg.EighResult QRResult = _linalg.QRResult SlogdetResult = _linalg.SlogdetResult SVDResult = _linalg.SVDResult # TODO: use the QR wrapper once dask # supports the mode keyword on QR # https://github.com/dask/dask/issues/10388 #qr = get_xp(da)(_linalg.qr) def qr(x: Array, mode: Literal['reduced', 'complete'] = 'reduced', **kwargs) -> QRResult: if mode != "reduced": raise ValueError("dask arrays only support using mode='reduced'") return QRResult(*da.linalg.qr(x, **kwargs)) trace = get_xp(da)(_linalg.trace) cholesky = get_xp(da)(_linalg.cholesky) matrix_rank = get_xp(da)(_linalg.matrix_rank) matrix_norm = get_xp(da)(_linalg.matrix_norm) # Wrap svd so that full_matrices is never passed through to dask: dask has no # full_matrices keyword and only implements the full_matrices=False behaviour, # so anything else is rejected. def svd(x: Array, full_matrices: bool = True, **kwargs) -> SVDResult: if full_matrices: raise ValueError("full_matrices=True is not supported by dask.") return da.linalg.svd(x, coerce_signs=False, **kwargs) def svdvals(x: Array) -> Array: # TODO: can't avoid computing U or V for dask _, s, _ = svd(x) return s vector_norm = get_xp(da)(_linalg.vector_norm) diagonal = get_xp(da)(_linalg.diagonal) __all__ = linalg_all + ["trace", "outer", "matmul", "tensordot", "matrix_transpose", "vecdot", "EighResult", "QRResult", "SlogdetResult", "SVDResult", "qr", "cholesky", "matrix_rank", "matrix_norm", "svdvals", "vector_norm", "diagonal"] _all_ignore = ['get_xp', 'da', 'linalg_all'] array-api-compat-1.11.2/array_api_compat/numpy/000077500000000000000000000000001476700770300214375ustar00rootroot00000000000000array-api-compat-1.11.2/array_api_compat/numpy/__init__.py000066400000000000000000000014771476700770300235610ustar00rootroot00000000000000from numpy import * # noqa: F403 # from numpy import * doesn't overwrite these builtin names from numpy import abs, max, min, round # noqa: F401 # These imports may overwrite names from the import * above. from ._aliases import * # noqa: F403 # Don't know why, but we have to do an absolute import to import linalg. If we # instead do # # from . import linalg # # It doesn't overwrite np.linalg from above. The import is generated # dynamically so that the library can be vendored. __import__(__package__ + '.linalg') __import__(__package__ + '.fft') from .linalg import matrix_transpose, vecdot # noqa: F401 from ..common._helpers import * # noqa: F403 try: # Used in asarray(). Not present in older versions. 
from numpy import _CopyMode # noqa: F401 except ImportError: pass __array_api_version__ = '2024.12' array-api-compat-1.11.2/array_api_compat/numpy/_aliases.py000066400000000000000000000117261476700770300236000ustar00rootroot00000000000000from __future__ import annotations from ..common import _aliases from .._internal import get_xp from ._info import __array_namespace_info__ from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Optional, Union from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol import numpy as np bool = np.bool_ # Basic renames acos = np.arccos acosh = np.arccosh asin = np.arcsin asinh = np.arcsinh atan = np.arctan atan2 = np.arctan2 atanh = np.arctanh bitwise_left_shift = np.left_shift bitwise_invert = np.invert bitwise_right_shift = np.right_shift concat = np.concatenate pow = np.power arange = get_xp(np)(_aliases.arange) empty = get_xp(np)(_aliases.empty) empty_like = get_xp(np)(_aliases.empty_like) eye = get_xp(np)(_aliases.eye) full = get_xp(np)(_aliases.full) full_like = get_xp(np)(_aliases.full_like) linspace = get_xp(np)(_aliases.linspace) ones = get_xp(np)(_aliases.ones) ones_like = get_xp(np)(_aliases.ones_like) zeros = get_xp(np)(_aliases.zeros) zeros_like = get_xp(np)(_aliases.zeros_like) UniqueAllResult = get_xp(np)(_aliases.UniqueAllResult) UniqueCountsResult = get_xp(np)(_aliases.UniqueCountsResult) UniqueInverseResult = get_xp(np)(_aliases.UniqueInverseResult) unique_all = get_xp(np)(_aliases.unique_all) unique_counts = get_xp(np)(_aliases.unique_counts) unique_inverse = get_xp(np)(_aliases.unique_inverse) unique_values = get_xp(np)(_aliases.unique_values) std = get_xp(np)(_aliases.std) var = get_xp(np)(_aliases.var) cumulative_sum = get_xp(np)(_aliases.cumulative_sum) cumulative_prod = get_xp(np)(_aliases.cumulative_prod) clip = get_xp(np)(_aliases.clip) permute_dims = get_xp(np)(_aliases.permute_dims) reshape = get_xp(np)(_aliases.reshape) argsort = get_xp(np)(_aliases.argsort) sort = get_xp(np)(_aliases.sort) nonzero = get_xp(np)(_aliases.nonzero) ceil = get_xp(np)(_aliases.ceil) floor = get_xp(np)(_aliases.floor) trunc = get_xp(np)(_aliases.trunc) matmul = get_xp(np)(_aliases.matmul) matrix_transpose = get_xp(np)(_aliases.matrix_transpose) tensordot = get_xp(np)(_aliases.tensordot) sign = get_xp(np)(_aliases.sign) def _supports_buffer_protocol(obj): try: memoryview(obj) except TypeError: return False return True # asarray also adds the copy keyword, which is not present in numpy 1.0. # asarray() is different enough between numpy, cupy, and dask, the logic # complicated enough that it's easier to define it separately for each module # rather than trying to combine everything into one function in common/ def asarray( obj: Union[ ndarray, bool, int, float, NestedSequence[bool | int | float], SupportsBufferProtocol, ], /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, copy: "Optional[Union[bool, np._CopyMode]]" = None, **kwargs, ) -> ndarray: """ Array API compatibility wrapper for asarray(). See the corresponding documentation in the array library and/or the array API specification for more details. """ if device not in ["cpu", None]: raise ValueError(f"Unsupported device for NumPy: {device!r}") if hasattr(np, '_CopyMode'): if copy is None: copy = np._CopyMode.IF_NEEDED elif copy is False: copy = np._CopyMode.NEVER elif copy is True: copy = np._CopyMode.ALWAYS else: # Not present in older NumPys. In this case, we cannot really support # copy=False. 
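        # On these older NumPys, np.array's copy flag is a plain bool whose False
        # value means "copy if needed", so a strict copy=False ("never copy")
        # request cannot be honoured and is rejected below; copy=None and
        # copy=True are forwarded unchanged and behave as copy-if-needed /
        # always-copy.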
if copy is False: raise NotImplementedError("asarray(copy=False) requires a newer version of NumPy.") return np.array(obj, copy=copy, dtype=dtype, **kwargs) def astype( x: ndarray, dtype: Dtype, /, *, copy: bool = True, device: Optional[Device] = None, ) -> ndarray: return x.astype(dtype=dtype, copy=copy) # count_nonzero returns a python int for axis=None and keepdims=False # https://github.com/numpy/numpy/issues/17562 def count_nonzero( x : ndarray, axis=None, keepdims=False ) -> ndarray: result = np.count_nonzero(x, axis=axis, keepdims=keepdims) if axis is None and not keepdims: return np.asarray(result) return result # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(np, 'vecdot'): vecdot = np.vecdot else: vecdot = get_xp(np)(_aliases.vecdot) if hasattr(np, 'isdtype'): isdtype = np.isdtype else: isdtype = get_xp(np)(_aliases.isdtype) if hasattr(np, 'unstack'): unstack = np.unstack else: unstack = get_xp(np)(_aliases.unstack) __all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', 'bool', 'concat', 'count_nonzero', 'pow'] _all_ignore = ['np', 'get_xp'] array-api-compat-1.11.2/array_api_compat/numpy/_info.py000066400000000000000000000242241476700770300231070ustar00rootroot00000000000000""" Array API Inspection namespace This is the namespace for inspection functions as defined by the array API standard. See https://data-apis.org/array-api/latest/API_specification/inspection.html for more details. """ from numpy import ( dtype, bool_ as bool, intp, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, complex64, complex128, ) class __array_namespace_info__: """ Get the array API inspection namespace for NumPy. The array API inspection namespace defines the following functions: - capabilities() - default_device() - default_dtypes() - dtypes() - devices() See https://data-apis.org/array-api/latest/API_specification/inspection.html for more details. Returns ------- info : ModuleType The array API inspection namespace for NumPy. Examples -------- >>> info = np.__array_namespace_info__() >>> info.default_dtypes() {'real floating': numpy.float64, 'complex floating': numpy.complex128, 'integral': numpy.int64, 'indexing': numpy.int64} """ __module__ = 'numpy' def capabilities(self): """ Return a dictionary of array API library capabilities. The resulting dictionary has the following keys: - **"boolean indexing"**: boolean indicating whether an array library supports boolean indexing. Always ``True`` for NumPy. - **"data-dependent shapes"**: boolean indicating whether an array library supports data-dependent output shapes. Always ``True`` for NumPy. See https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html for more details. See Also -------- __array_namespace_info__.default_device, __array_namespace_info__.default_dtypes, __array_namespace_info__.dtypes, __array_namespace_info__.devices Returns ------- capabilities : dict A dictionary of array API library capabilities. 
Examples -------- >>> info = np.__array_namespace_info__() >>> info.capabilities() {'boolean indexing': True, 'data-dependent shapes': True} """ return { "boolean indexing": True, "data-dependent shapes": True, # 'max rank' will be part of the 2024.12 standard "max dimensions": 64, } def default_device(self): """ The default device used for new NumPy arrays. For NumPy, this always returns ``'cpu'``. See Also -------- __array_namespace_info__.capabilities, __array_namespace_info__.default_dtypes, __array_namespace_info__.dtypes, __array_namespace_info__.devices Returns ------- device : str The default device used for new NumPy arrays. Examples -------- >>> info = np.__array_namespace_info__() >>> info.default_device() 'cpu' """ return "cpu" def default_dtypes(self, *, device=None): """ The default data types used for new NumPy arrays. For NumPy, this always returns the following dictionary: - **"real floating"**: ``numpy.float64`` - **"complex floating"**: ``numpy.complex128`` - **"integral"**: ``numpy.intp`` - **"indexing"**: ``numpy.intp`` Parameters ---------- device : str, optional The device to get the default data types for. For NumPy, only ``'cpu'`` is allowed. Returns ------- dtypes : dict A dictionary describing the default data types used for new NumPy arrays. See Also -------- __array_namespace_info__.capabilities, __array_namespace_info__.default_device, __array_namespace_info__.dtypes, __array_namespace_info__.devices Examples -------- >>> info = np.__array_namespace_info__() >>> info.default_dtypes() {'real floating': numpy.float64, 'complex floating': numpy.complex128, 'integral': numpy.int64, 'indexing': numpy.int64} """ if device not in ["cpu", None]: raise ValueError( 'Device not understood. Only "cpu" is allowed, but received:' f' {device}' ) return { "real floating": dtype(float64), "complex floating": dtype(complex128), "integral": dtype(intp), "indexing": dtype(intp), } def dtypes(self, *, device=None, kind=None): """ The array API data types supported by NumPy. Note that this function only returns data types that are defined by the array API. Parameters ---------- device : str, optional The device to get the data types for. For NumPy, only ``'cpu'`` is allowed. kind : str or tuple of str, optional The kind of data types to return. If ``None``, all data types are returned. If a string, only data types of that kind are returned. If a tuple, a dictionary containing the union of the given kinds is returned. The following kinds are supported: - ``'bool'``: boolean data types (i.e., ``bool``). - ``'signed integer'``: signed integer data types (i.e., ``int8``, ``int16``, ``int32``, ``int64``). - ``'unsigned integer'``: unsigned integer data types (i.e., ``uint8``, ``uint16``, ``uint32``, ``uint64``). - ``'integral'``: integer data types. Shorthand for ``('signed integer', 'unsigned integer')``. - ``'real floating'``: real-valued floating-point data types (i.e., ``float32``, ``float64``). - ``'complex floating'``: complex floating-point data types (i.e., ``complex64``, ``complex128``). - ``'numeric'``: numeric data types. Shorthand for ``('integral', 'real floating', 'complex floating')``. Returns ------- dtypes : dict A dictionary mapping the names of data types to the corresponding NumPy data types. 
See Also -------- __array_namespace_info__.capabilities, __array_namespace_info__.default_device, __array_namespace_info__.default_dtypes, __array_namespace_info__.devices Examples -------- >>> info = np.__array_namespace_info__() >>> info.dtypes(kind='signed integer') {'int8': numpy.int8, 'int16': numpy.int16, 'int32': numpy.int32, 'int64': numpy.int64} """ if device not in ["cpu", None]: raise ValueError( 'Device not understood. Only "cpu" is allowed, but received:' f' {device}' ) if kind is None: return { "bool": dtype(bool), "int8": dtype(int8), "int16": dtype(int16), "int32": dtype(int32), "int64": dtype(int64), "uint8": dtype(uint8), "uint16": dtype(uint16), "uint32": dtype(uint32), "uint64": dtype(uint64), "float32": dtype(float32), "float64": dtype(float64), "complex64": dtype(complex64), "complex128": dtype(complex128), } if kind == "bool": return {"bool": bool} if kind == "signed integer": return { "int8": dtype(int8), "int16": dtype(int16), "int32": dtype(int32), "int64": dtype(int64), } if kind == "unsigned integer": return { "uint8": dtype(uint8), "uint16": dtype(uint16), "uint32": dtype(uint32), "uint64": dtype(uint64), } if kind == "integral": return { "int8": dtype(int8), "int16": dtype(int16), "int32": dtype(int32), "int64": dtype(int64), "uint8": dtype(uint8), "uint16": dtype(uint16), "uint32": dtype(uint32), "uint64": dtype(uint64), } if kind == "real floating": return { "float32": dtype(float32), "float64": dtype(float64), } if kind == "complex floating": return { "complex64": dtype(complex64), "complex128": dtype(complex128), } if kind == "numeric": return { "int8": dtype(int8), "int16": dtype(int16), "int32": dtype(int32), "int64": dtype(int64), "uint8": dtype(uint8), "uint16": dtype(uint16), "uint32": dtype(uint32), "uint64": dtype(uint64), "float32": dtype(float32), "float64": dtype(float64), "complex64": dtype(complex64), "complex128": dtype(complex128), } if isinstance(kind, tuple): res = {} for k in kind: res.update(self.dtypes(kind=k)) return res raise ValueError(f"unsupported kind: {kind!r}") def devices(self): """ The devices supported by NumPy. For NumPy, this always returns ``['cpu']``. Returns ------- devices : list of str The devices supported by NumPy. 
See Also -------- __array_namespace_info__.capabilities, __array_namespace_info__.default_device, __array_namespace_info__.default_dtypes, __array_namespace_info__.dtypes Examples -------- >>> info = np.__array_namespace_info__() >>> info.devices() ['cpu'] """ return ["cpu"] array-api-compat-1.11.2/array_api_compat/numpy/_typing.py000066400000000000000000000011521476700770300234610ustar00rootroot00000000000000from __future__ import annotations __all__ = [ "ndarray", "Device", "Dtype", ] import sys from typing import ( Literal, Union, TYPE_CHECKING, ) from numpy import ( ndarray, dtype, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, ) Device = Literal["cpu"] if TYPE_CHECKING or sys.version_info >= (3, 9): Dtype = dtype[Union[ int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, ]] else: Dtype = dtype array-api-compat-1.11.2/array_api_compat/numpy/fft.py000066400000000000000000000012471476700770300225740ustar00rootroot00000000000000from numpy.fft import * # noqa: F403 from numpy.fft import __all__ as fft_all from ..common import _fft from .._internal import get_xp import numpy as np fft = get_xp(np)(_fft.fft) ifft = get_xp(np)(_fft.ifft) fftn = get_xp(np)(_fft.fftn) ifftn = get_xp(np)(_fft.ifftn) rfft = get_xp(np)(_fft.rfft) irfft = get_xp(np)(_fft.irfft) rfftn = get_xp(np)(_fft.rfftn) irfftn = get_xp(np)(_fft.irfftn) hfft = get_xp(np)(_fft.hfft) ihfft = get_xp(np)(_fft.ihfft) fftfreq = get_xp(np)(_fft.fftfreq) rfftfreq = get_xp(np)(_fft.rfftfreq) fftshift = get_xp(np)(_fft.fftshift) ifftshift = get_xp(np)(_fft.ifftshift) __all__ = fft_all + _fft.__all__ del get_xp del np del fft_all del _fft array-api-compat-1.11.2/array_api_compat/numpy/linalg.py000066400000000000000000000062701476700770300232640ustar00rootroot00000000000000from numpy.linalg import * # noqa: F403 from numpy.linalg import __all__ as linalg_all import numpy as _np from ..common import _linalg from .._internal import get_xp # These functions are in both the main and linalg namespaces from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401 import numpy as np cross = get_xp(np)(_linalg.cross) outer = get_xp(np)(_linalg.outer) EighResult = _linalg.EighResult QRResult = _linalg.QRResult SlogdetResult = _linalg.SlogdetResult SVDResult = _linalg.SVDResult eigh = get_xp(np)(_linalg.eigh) qr = get_xp(np)(_linalg.qr) slogdet = get_xp(np)(_linalg.slogdet) svd = get_xp(np)(_linalg.svd) cholesky = get_xp(np)(_linalg.cholesky) matrix_rank = get_xp(np)(_linalg.matrix_rank) pinv = get_xp(np)(_linalg.pinv) matrix_norm = get_xp(np)(_linalg.matrix_norm) svdvals = get_xp(np)(_linalg.svdvals) diagonal = get_xp(np)(_linalg.diagonal) trace = get_xp(np)(_linalg.trace) # Note: unlike np.linalg.solve, the array API solve() only accepts x2 as a # vector when it is exactly 1-dimensional. All other cases treat x2 as a stack # of matrices. The np.linalg.solve behavior of allowing stacks of both # matrices and vectors is ambiguous c.f. # https://github.com/numpy/numpy/issues/15349 and # https://github.com/data-apis/array-api/issues/285. # To workaround this, the below is the code from np.linalg.solve except # only calling solve1 in the exactly 1D case. # This code is here instead of in common because it is numpy specific. Also # note that CuPy's solve() does not currently support broadcasting (see # https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43). 
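# A sketch of the behavioral difference (shapes illustrative): with x1 of shape # (5, 3, 3), an x2 of shape (3,) takes the 1-D vector path below and yields # shape (5, 3), while an x2 of shape (3, 1) is treated as a stack of matrices # and yields shape (5, 3, 1).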
def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray: try: from numpy.linalg._linalg import ( _makearray, _assert_stacked_2d, _assert_stacked_square, _commonType, isComplexType, _raise_linalgerror_singular ) except ImportError: from numpy.linalg.linalg import ( _makearray, _assert_stacked_2d, _assert_stacked_square, _commonType, isComplexType, _raise_linalgerror_singular ) from numpy.linalg import _umath_linalg x1, _ = _makearray(x1) _assert_stacked_2d(x1) _assert_stacked_square(x1) x2, wrap = _makearray(x2) t, result_t = _commonType(x1, x2) # This part is different from np.linalg.solve if x2.ndim == 1: gufunc = _umath_linalg.solve1 else: gufunc = _umath_linalg.solve # This does nothing currently but is left in because it will be relevant # when complex dtype support is added to the spec in 2022. signature = 'DD->D' if isComplexType(t) else 'dd->d' with _np.errstate(call=_raise_linalgerror_singular, invalid='call', over='ignore', divide='ignore', under='ignore'): r = gufunc(x1, x2, signature=signature) return wrap(r.astype(result_t, copy=False)) # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(np.linalg, 'vector_norm'): vector_norm = np.linalg.vector_norm else: vector_norm = get_xp(np)(_linalg.vector_norm) __all__ = linalg_all + _linalg.__all__ + ['solve'] del get_xp del np del linalg_all del _linalg array-api-compat-1.11.2/array_api_compat/torch/000077500000000000000000000000001476700770300214065ustar00rootroot00000000000000array-api-compat-1.11.2/array_api_compat/torch/__init__.py000066400000000000000000000011171476700770300235170ustar00rootroot00000000000000from torch import * # noqa: F403 # Several names are not included in the above import * import torch for n in dir(torch): if (n.startswith('_') or n.endswith('_') or 'cuda' in n or 'cpu' in n or 'backward' in n): continue exec(n + ' = torch.' + n) # These imports may overwrite names from the import * above. 
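# For example, the wrapped reductions defined in ._aliases (max, min, prod, # sum, sort, ...) shadow the plain torch callables copied in by the loop above.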
from ._aliases import * # noqa: F403 # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') __import__(__package__ + '.fft') from ..common._helpers import * # noqa: F403 __array_api_version__ = '2024.12' array-api-compat-1.11.2/array_api_compat/torch/_aliases.py000066400000000000000000000725511476700770300235520ustar00rootroot00000000000000from __future__ import annotations from functools import reduce as _reduce, wraps as _wraps from builtins import all as _builtin_all, any as _builtin_any from ..common import _aliases from .._internal import get_xp from ._info import __array_namespace_info__ import torch from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import List, Optional, Sequence, Tuple, Union from ..common._typing import Device from torch import dtype as Dtype array = torch.Tensor _int_dtypes = { torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, } try: # torch >=2.3 _int_dtypes |= {torch.uint16, torch.uint32, torch.uint64} except AttributeError: pass _array_api_dtypes = { torch.bool, *_int_dtypes, torch.float32, torch.float64, torch.complex64, torch.complex128, } _promotion_table = { # bool (torch.bool, torch.bool): torch.bool, # ints (torch.int8, torch.int8): torch.int8, (torch.int8, torch.int16): torch.int16, (torch.int8, torch.int32): torch.int32, (torch.int8, torch.int64): torch.int64, (torch.int16, torch.int8): torch.int16, (torch.int16, torch.int16): torch.int16, (torch.int16, torch.int32): torch.int32, (torch.int16, torch.int64): torch.int64, (torch.int32, torch.int8): torch.int32, (torch.int32, torch.int16): torch.int32, (torch.int32, torch.int32): torch.int32, (torch.int32, torch.int64): torch.int64, (torch.int64, torch.int8): torch.int64, (torch.int64, torch.int16): torch.int64, (torch.int64, torch.int32): torch.int64, (torch.int64, torch.int64): torch.int64, # uints (torch.uint8, torch.uint8): torch.uint8, # ints and uints (mixed sign) (torch.int8, torch.uint8): torch.int16, (torch.int16, torch.uint8): torch.int16, (torch.int32, torch.uint8): torch.int32, (torch.int64, torch.uint8): torch.int64, (torch.uint8, torch.int8): torch.int16, (torch.uint8, torch.int16): torch.int16, (torch.uint8, torch.int32): torch.int32, (torch.uint8, torch.int64): torch.int64, # floats (torch.float32, torch.float32): torch.float32, (torch.float32, torch.float64): torch.float64, (torch.float64, torch.float32): torch.float64, (torch.float64, torch.float64): torch.float64, # complexes (torch.complex64, torch.complex64): torch.complex64, (torch.complex64, torch.complex128): torch.complex128, (torch.complex128, torch.complex64): torch.complex128, (torch.complex128, torch.complex128): torch.complex128, # Mixed float and complex (torch.float32, torch.complex64): torch.complex64, (torch.float32, torch.complex128): torch.complex128, (torch.float64, torch.complex64): torch.complex128, (torch.float64, torch.complex128): torch.complex128, } def _two_arg(f): @_wraps(f) def _f(x1, x2, /, **kwargs): x1, x2 = _fix_promotion(x1, x2) return f(x1, x2, **kwargs) if _f.__doc__ is None: _f.__doc__ = f"""\ Array API compatibility wrapper for torch.{f.__name__}. See the corresponding PyTorch documentation and/or the array API specification for more details. 
""" return _f def _fix_promotion(x1, x2, only_scalar=True): if not isinstance(x1, torch.Tensor) or not isinstance(x2, torch.Tensor): return x1, x2 if x1.dtype not in _array_api_dtypes or x2.dtype not in _array_api_dtypes: return x1, x2 # If an argument is 0-D pytorch downcasts the other argument if not only_scalar or x1.shape == (): dtype = result_type(x1, x2) x2 = x2.to(dtype) if not only_scalar or x2.shape == (): dtype = result_type(x1, x2) x1 = x1.to(dtype) return x1, x2 _py_scalars = (bool, int, float, complex) def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, complex]) -> Dtype: num = len(arrays_and_dtypes) if num == 0: raise ValueError("At least one array or dtype must be provided") elif num == 1: x = arrays_and_dtypes[0] if isinstance(x, torch.dtype): return x return x.dtype if num == 2: x, y = arrays_and_dtypes return _result_type(x, y) else: # sort scalars so that they are treated last scalars, others = [], [] for x in arrays_and_dtypes: if isinstance(x, _py_scalars): scalars.append(x) else: others.append(x) if not others: raise ValueError("At least one array or dtype must be provided") # combine left-to-right return _reduce(_result_type, others + scalars) def _result_type(x, y): if not (isinstance(x, _py_scalars) or isinstance(y, _py_scalars)): xdt = x.dtype if not isinstance(x, torch.dtype) else x ydt = y.dtype if not isinstance(y, torch.dtype) else y if (xdt, ydt) in _promotion_table: return _promotion_table[xdt, ydt] # This doesn't result_type(dtype, dtype) for non-array API dtypes # because torch.result_type only accepts tensors. This does however, allow # cross-kind promotion. x = torch.tensor([], dtype=x) if isinstance(x, torch.dtype) else x y = torch.tensor([], dtype=y) if isinstance(y, torch.dtype) else y return torch.result_type(x, y) def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool: if not isinstance(from_, torch.dtype): from_ = from_.dtype return torch.can_cast(from_, to) # Basic renames bitwise_invert = torch.bitwise_not newaxis = None # torch.conj sets the conjugation bit, which breaks conversion to other # libraries. See https://github.com/data-apis/array-api-compat/issues/173 conj = torch.conj_physical # Two-arg elementwise functions # These require a wrapper to do the correct type promotion on 0-D tensors add = _two_arg(torch.add) atan2 = _two_arg(torch.atan2) bitwise_and = _two_arg(torch.bitwise_and) bitwise_left_shift = _two_arg(torch.bitwise_left_shift) bitwise_or = _two_arg(torch.bitwise_or) bitwise_right_shift = _two_arg(torch.bitwise_right_shift) bitwise_xor = _two_arg(torch.bitwise_xor) copysign = _two_arg(torch.copysign) divide = _two_arg(torch.divide) # Also a rename. torch.equal does not broadcast equal = _two_arg(torch.eq) floor_divide = _two_arg(torch.floor_divide) greater = _two_arg(torch.greater) greater_equal = _two_arg(torch.greater_equal) hypot = _two_arg(torch.hypot) less = _two_arg(torch.less) less_equal = _two_arg(torch.less_equal) logaddexp = _two_arg(torch.logaddexp) # logical functions are not included here because they only accept bool in the # spec, so type promotion is irrelevant. maximum = _two_arg(torch.maximum) minimum = _two_arg(torch.minimum) multiply = _two_arg(torch.multiply) not_equal = _two_arg(torch.not_equal) pow = _two_arg(torch.pow) remainder = _two_arg(torch.remainder) subtract = _two_arg(torch.subtract) # These wrappers are mostly based on the fact that pytorch uses 'dim' instead # of 'axis'. 
# torch.min and torch.max return a tuple and don't support multiple axes https://github.com/pytorch/pytorch/issues/58745 def max(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) return torch.amax(x, axis, keepdims=keepdims) def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) return torch.amin(x, axis, keepdims=keepdims) clip = get_xp(torch)(_aliases.clip) unstack = get_xp(torch)(_aliases.unstack) cumulative_sum = get_xp(torch)(_aliases.cumulative_sum) cumulative_prod = get_xp(torch)(_aliases.cumulative_prod) # torch.sort also returns a tuple # https://github.com/pytorch/pytorch/issues/70921 def sort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> array: return torch.sort(x, dim=axis, descending=descending, stable=stable, **kwargs).values def _normalize_axes(axis, ndim): axes = [] if ndim == 0 and axis: # Better error message in this case raise IndexError(f"Dimension out of range: {axis[0]}") lower, upper = -ndim, ndim - 1 for a in axis: if a < lower or a > upper: # Match torch error message (e.g., from sum()) raise IndexError(f"Dimension out of range (expected to be in range of [{lower}, {upper}], but got {a})") if a < 0: a = a + ndim if a in axes: # Use IndexError instead of RuntimeError, and "axis" instead of "dim" raise IndexError(f"Axis {a} appears multiple times in the list of axes") axes.append(a) return sorted(axes) def _axis_none_keepdims(x, ndim, keepdims): # Apply keepdims when axis=None # (https://github.com/pytorch/pytorch/issues/71209) # Note that this is only valid for the axis=None case. if keepdims: for i in range(ndim): x = torch.unsqueeze(x, 0) return x def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs): # Some reductions don't support multiple axes # (https://github.com/pytorch/pytorch/issues/56586). axes = _normalize_axes(axis, x.ndim) for a in reversed(axes): x = torch.movedim(x, a, -1) x = torch.flatten(x, -len(axes)) out = f(x, -1, **kwargs) if keepdims: for a in axes: out = torch.unsqueeze(out, a) return out def prod(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype: Optional[Dtype] = None, keepdims: bool = False, **kwargs) -> array: x = torch.asarray(x) ndim = x.ndim # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic # below because it still needs to upcast. if axis == (): if dtype is None: # We can't upcast uint8 according to the spec because there is no # torch.uint64, so at least upcast to int64 which is what sum does # when axis=None. if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]: return x.to(torch.int64) return x.clone() return x.to(dtype) # torch.prod doesn't support multiple axes # (https://github.com/pytorch/pytorch/issues/56586).
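# Tuple axes fall back to _reduce_multiple_axes above, which moves the reduced # dimensions to the end, flattens them, and reduces once over the combined # trailing dimension.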
if isinstance(axis, tuple): return _reduce_multiple_axes(torch.prod, x, axis, keepdims=keepdims, dtype=dtype, **kwargs) if axis is None: # torch doesn't support keepdims with axis=None # (https://github.com/pytorch/pytorch/issues/71209) res = torch.prod(x, dtype=dtype, **kwargs) res = _axis_none_keepdims(res, ndim, keepdims) return res return torch.prod(x, axis, dtype=dtype, keepdims=keepdims, **kwargs) def sum(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype: Optional[Dtype] = None, keepdims: bool = False, **kwargs) -> array: x = torch.asarray(x) ndim = x.ndim # https://github.com/pytorch/pytorch/issues/29137. # Make sure it upcasts. if axis == (): if dtype is None: # We can't upcast uint8 according to the spec because there is no # torch.uint64, so at least upcast to int64 which is what sum does # when axis=None. if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]: return x.to(torch.int64) return x.clone() return x.to(dtype) if axis is None: # torch doesn't support keepdims with axis=None # (https://github.com/pytorch/pytorch/issues/71209) res = torch.sum(x, dtype=dtype, **kwargs) res = _axis_none_keepdims(res, ndim, keepdims) return res return torch.sum(x, axis, dtype=dtype, keepdims=keepdims, **kwargs) def any(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, **kwargs) -> array: x = torch.asarray(x) ndim = x.ndim if axis == (): return x.to(torch.bool) # torch.any doesn't support multiple axes # (https://github.com/pytorch/pytorch/issues/56586). if isinstance(axis, tuple): res = _reduce_multiple_axes(torch.any, x, axis, keepdims=keepdims, **kwargs) return res.to(torch.bool) if axis is None: # torch doesn't support keepdims with axis=None # (https://github.com/pytorch/pytorch/issues/71209) res = torch.any(x, **kwargs) res = _axis_none_keepdims(res, ndim, keepdims) return res.to(torch.bool) # torch.any doesn't return bool for uint8 return torch.any(x, axis, keepdims=keepdims).to(torch.bool) def all(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, **kwargs) -> array: x = torch.asarray(x) ndim = x.ndim if axis == (): return x.to(torch.bool) # torch.all doesn't support multiple axes # (https://github.com/pytorch/pytorch/issues/56586). if isinstance(axis, tuple): res = _reduce_multiple_axes(torch.all, x, axis, keepdims=keepdims, **kwargs) return res.to(torch.bool) if axis is None: # torch doesn't support keepdims with axis=None # (https://github.com/pytorch/pytorch/issues/71209) res = torch.all(x, **kwargs) res = _axis_none_keepdims(res, ndim, keepdims) return res.to(torch.bool) # torch.all doesn't return bool for uint8 return torch.all(x, axis, keepdims=keepdims).to(torch.bool) def mean(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, **kwargs) -> array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) if axis is None: # torch doesn't support keepdims with axis=None # (https://github.com/pytorch/pytorch/issues/71209) res = torch.mean(x, **kwargs) res = _axis_none_keepdims(res, x.ndim, keepdims) return res return torch.mean(x, axis, keepdims=keepdims, **kwargs) def std(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False, **kwargs) -> array: # Note, float correction is not supported # https://github.com/pytorch/pytorch/issues/61492. We don't try to # implement it here for now. 
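# An integer-valued float such as correction=2.0 is truncated and accepted # below; a genuinely fractional value raises NotImplementedError.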
if isinstance(correction, float): _correction = int(correction) if correction != _correction: raise NotImplementedError("float correction in torch std() is not yet supported") else: _correction = correction # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.zeros_like(x) if isinstance(axis, int): axis = (axis,) if axis is None: # torch doesn't support keepdims with axis=None # (https://github.com/pytorch/pytorch/issues/71209) res = torch.std(x, tuple(range(x.ndim)), correction=_correction, **kwargs) res = _axis_none_keepdims(res, x.ndim, keepdims) return res return torch.std(x, axis, correction=_correction, keepdims=keepdims, **kwargs) def var(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False, **kwargs) -> array: # Note, float correction is not supported # https://github.com/pytorch/pytorch/issues/61492. We don't try to # implement it here for now. # if isinstance(correction, float): # correction = int(correction) # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.zeros_like(x) if isinstance(axis, int): axis = (axis,) if axis is None: # torch doesn't support keepdims with axis=None # (https://github.com/pytorch/pytorch/issues/71209) res = torch.var(x, tuple(range(x.ndim)), correction=correction, **kwargs) res = _axis_none_keepdims(res, x.ndim, keepdims) return res return torch.var(x, axis, correction=correction, keepdims=keepdims, **kwargs) # torch.concat doesn't support dim=None # https://github.com/pytorch/pytorch/issues/70925 def concat(arrays: Union[Tuple[array, ...], List[array]], /, *, axis: Optional[int] = 0, **kwargs) -> array: if axis is None: arrays = tuple(ar.flatten() for ar in arrays) axis = 0 return torch.concat(arrays, axis, **kwargs) # torch.squeeze only accepts int dim and doesn't require it # https://github.com/pytorch/pytorch/issues/70924. Support for tuple dim was # added at https://github.com/pytorch/pytorch/pull/89017. def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array: if isinstance(axis, int): axis = (axis,) for a in axis: if x.shape[a] != 1: raise ValueError("squeezed dimensions must be equal to 1") axes = _normalize_axes(axis, x.ndim) # Remove this once pytorch 1.14 is released with the above PR #89017. sequence = [a - i for i, a in enumerate(axes)] for a in sequence: x = torch.squeeze(x, a) return x # torch.broadcast_to uses size instead of shape def broadcast_to(x: array, /, shape: Tuple[int, ...], **kwargs) -> array: return torch.broadcast_to(x, shape, **kwargs) # torch.permute uses dims instead of axes def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array: return torch.permute(x, axes) # The axis parameter doesn't work for flip() and roll() # https://github.com/pytorch/pytorch/issues/71210. 
Also torch.flip() doesn't # accept axis=None def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array: if axis is None: axis = tuple(range(x.ndim)) # torch.flip doesn't accept dim as an int but the method does # https://github.com/pytorch/pytorch/issues/18095 return x.flip(axis, **kwargs) def roll(x: array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array: return torch.roll(x, shift, axis, **kwargs) def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]: if x.ndim == 0: raise ValueError("nonzero() does not support zero-dimensional arrays") return torch.nonzero(x, as_tuple=True, **kwargs) # torch uses `dim` instead of `axis` def diff( x: array, /, *, axis: int = -1, n: int = 1, prepend: Optional[array] = None, append: Optional[array] = None, ) -> array: return torch.diff(x, dim=axis, n=n, prepend=prepend, append=append) # torch uses `dim` instead of `axis`, does not have keepdims def count_nonzero( x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ) -> array: result = torch.count_nonzero(x, dim=axis) if keepdims: if axis is not None: return result.unsqueeze(axis) return _axis_none_keepdims(result, x.ndim, keepdims) else: return result def where(condition: array, x1: array, x2: array, /) -> array: x1, x2 = _fix_promotion(x1, x2) return torch.where(condition, x1, x2) # torch.reshape doesn't have the copy keyword def reshape(x: array, /, shape: Tuple[int, ...], copy: Optional[bool] = None, **kwargs) -> array: if copy is not None: raise NotImplementedError("torch.reshape doesn't yet support the copy keyword") return torch.reshape(x, shape, **kwargs) # torch.arange doesn't support returning empty arrays # (https://github.com/pytorch/pytorch/issues/70915), and doesn't support some # keyword argument combinations # (https://github.com/pytorch/pytorch/issues/70914) def arange(start: Union[int, float], /, stop: Optional[Union[int, float]] = None, step: Union[int, float] = 1, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, **kwargs) -> array: if stop is None: start, stop = 0, start if step > 0 and stop <= start or step < 0 and stop >= start: if dtype is None: if _builtin_all(isinstance(i, int) for i in [start, stop, step]): dtype = torch.int64 else: dtype = torch.float32 return torch.empty(0, dtype=dtype, device=device, **kwargs) return torch.arange(start, stop, step, dtype=dtype, device=device, **kwargs) # torch.eye does not accept None as a default for the second argument and # doesn't support off-diagonals (https://github.com/pytorch/pytorch/issues/70910) def eye(n_rows: int, n_cols: Optional[int] = None, /, *, k: int = 0, dtype: Optional[Dtype] = None, device: Optional[Device] = None, **kwargs) -> array: if n_cols is None: n_cols = n_rows z = torch.zeros(n_rows, n_cols, dtype=dtype, device=device, **kwargs) if abs(k) <= n_rows + n_cols: z.diagonal(k).fill_(1) return z # torch.linspace doesn't have the endpoint parameter def linspace(start: Union[int, float], stop: Union[int, float], /, num: int, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, endpoint: bool = True, **kwargs) -> array: if not endpoint: return torch.linspace(start, stop, num+1, dtype=dtype, device=device, **kwargs)[:-1] return torch.linspace(start, stop, num, dtype=dtype, device=device, **kwargs) # torch.full does not accept an int size # https://github.com/pytorch/pytorch/issues/70906 def full(shape: Union[int, Tuple[int, ...]], fill_value: 
Union[bool, int, float, complex], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, **kwargs) -> array: if isinstance(shape, int): shape = (shape,) return torch.full(shape, fill_value, dtype=dtype, device=device, **kwargs) # ones, zeros, and empty do not accept shape as a keyword argument def ones(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, **kwargs) -> array: return torch.ones(shape, dtype=dtype, device=device, **kwargs) def zeros(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, **kwargs) -> array: return torch.zeros(shape, dtype=dtype, device=device, **kwargs) def empty(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, **kwargs) -> array: return torch.empty(shape, dtype=dtype, device=device, **kwargs) # tril and triu do not call the keyword argument k def tril(x: array, /, *, k: int = 0) -> array: return torch.tril(x, k) def triu(x: array, /, *, k: int = 0) -> array: return torch.triu(x, k) # Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742 def expand_dims(x: array, /, *, axis: int = 0) -> array: return torch.unsqueeze(x, axis) def astype( x: array, dtype: Dtype, /, *, copy: bool = True, device: Optional[Device] = None, ) -> array: if device is not None: return x.to(device, dtype=dtype, copy=copy) return x.to(dtype=dtype, copy=copy) def broadcast_arrays(*arrays: array) -> List[array]: shape = torch.broadcast_shapes(*[a.shape for a in arrays]) return [torch.broadcast_to(a, shape) for a in arrays] # Note that these named tuples aren't actually part of the standard namespace, # but I don't see any issue with exporting the names here regardless. from ..common._aliases import (UniqueAllResult, UniqueCountsResult, UniqueInverseResult) # https://github.com/pytorch/pytorch/issues/70920 def unique_all(x: array) -> UniqueAllResult: # torch.unique doesn't support returning indices. # https://github.com/pytorch/pytorch/issues/36748. The workaround # suggested in that issue doesn't actually function correctly (it relies # on non-deterministic behavior of scatter()). raise NotImplementedError("unique_all() not yet implemented for pytorch (see https://github.com/pytorch/pytorch/issues/36748)") # values, inverse_indices, counts = torch.unique(x, return_counts=True, return_inverse=True) # # torch.unique incorrectly gives a 0 count for nan values. # # https://github.com/pytorch/pytorch/issues/94106 # counts[torch.isnan(values)] = 1 # return UniqueAllResult(values, indices, inverse_indices, counts) def unique_counts(x: array) -> UniqueCountsResult: values, counts = torch.unique(x, return_counts=True) # torch.unique incorrectly gives a 0 count for nan values. 
# https://github.com/pytorch/pytorch/issues/94106 counts[torch.isnan(values)] = 1 return UniqueCountsResult(values, counts) def unique_inverse(x: array) -> UniqueInverseResult: values, inverse = torch.unique(x, return_inverse=True) return UniqueInverseResult(values, inverse) def unique_values(x: array) -> array: return torch.unique(x) def matmul(x1: array, x2: array, /, **kwargs) -> array: # torch.matmul doesn't type promote (but differently from _fix_promotion) x1, x2 = _fix_promotion(x1, x2, only_scalar=False) return torch.matmul(x1, x2, **kwargs) matrix_transpose = get_xp(torch)(_aliases.matrix_transpose) _vecdot = get_xp(torch)(_aliases.vecdot) def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array: x1, x2 = _fix_promotion(x1, x2, only_scalar=False) return _vecdot(x1, x2, axis=axis) # torch.tensordot uses dims instead of axes def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, **kwargs) -> array: # Note: torch.tensordot fails with integer dtypes when there is only 1 # element in the axis (https://github.com/pytorch/pytorch/issues/84530). x1, x2 = _fix_promotion(x1, x2, only_scalar=False) return torch.tensordot(x1, x2, dims=axes, **kwargs) def isdtype( dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], *, _tuple=True, # Disallow nested tuples ) -> bool: """ Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``. Note that outside of this function, this compat library does not yet fully support complex numbers. See https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html for more details """ if isinstance(kind, tuple) and _tuple: return _builtin_any(isdtype(dtype, k, _tuple=False) for k in kind) elif isinstance(kind, str): if kind == 'bool': return dtype == torch.bool elif kind == 'signed integer': return dtype in _int_dtypes and dtype.is_signed elif kind == 'unsigned integer': return dtype in _int_dtypes and not dtype.is_signed elif kind == 'integral': return dtype in _int_dtypes elif kind == 'real floating': return dtype.is_floating_point elif kind == 'complex floating': return dtype.is_complex elif kind == 'numeric': return isdtype(dtype, ('integral', 'real floating', 'complex floating')) else: raise ValueError(f"Unrecognized data type kind: {kind!r}") else: return dtype == kind def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -> array: if axis is None: if x.ndim != 1: raise ValueError("axis must be specified when ndim > 1") axis = 0 return torch.index_select(x, axis, indices, **kwargs) def take_along_axis(x: array, indices: array, /, *, axis: int = -1) -> array: return torch.take_along_dim(x, indices, dim=axis) def sign(x: array, /) -> array: # torch sign() does not support complex numbers and does not propagate # nans. 
See https://github.com/data-apis/array-api-compat/issues/136 if x.dtype.is_complex: out = x/torch.abs(x) # sign(0) = 0 but the above formula would give nan out[x == 0+0j] = 0+0j return out else: out = torch.sign(x) if x.dtype.is_floating_point: out[torch.isnan(x)] = torch.nan return out __all__ = ['__array_namespace_info__', 'result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero', 'diff', 'divide', 'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot', 'less', 'less_equal', 'logaddexp', 'maximum', 'minimum', 'multiply', 'not_equal', 'pow', 'remainder', 'subtract', 'max', 'min', 'clip', 'unstack', 'cumulative_sum', 'cumulative_prod', 'sort', 'prod', 'sum', 'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze', 'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape', 'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype', 'broadcast_arrays', 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype', 'take', 'take_along_axis', 'sign'] _all_ignore = ['torch', 'get_xp'] array-api-compat-1.11.2/array_api_compat/torch/_info.py000066400000000000000000000262311476700770300230560ustar00rootroot00000000000000""" Array API Inspection namespace This is the namespace for inspection functions as defined by the array API standard. See https://data-apis.org/array-api/latest/API_specification/inspection.html for more details. """ import torch from functools import cache class __array_namespace_info__: """ Get the array API inspection namespace for PyTorch. The array API inspection namespace defines the following functions: - capabilities() - default_device() - default_dtypes() - dtypes() - devices() See https://data-apis.org/array-api/latest/API_specification/inspection.html for more details. Returns ------- info : ModuleType The array API inspection namespace for PyTorch. Examples -------- >>> info = np.__array_namespace_info__() >>> info.default_dtypes() {'real floating': numpy.float64, 'complex floating': numpy.complex128, 'integral': numpy.int64, 'indexing': numpy.int64} """ __module__ = 'torch' def capabilities(self): """ Return a dictionary of array API library capabilities. The resulting dictionary has the following keys: - **"boolean indexing"**: boolean indicating whether an array library supports boolean indexing. Always ``True`` for PyTorch. - **"data-dependent shapes"**: boolean indicating whether an array library supports data-dependent output shapes. Always ``True`` for PyTorch. See https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html for more details. See Also -------- __array_namespace_info__.default_device, __array_namespace_info__.default_dtypes, __array_namespace_info__.dtypes, __array_namespace_info__.devices Returns ------- capabilities : dict A dictionary of array API library capabilities. Examples -------- >>> info = np.__array_namespace_info__() >>> info.capabilities() {'boolean indexing': True, 'data-dependent shapes': True} """ return { "boolean indexing": True, "data-dependent shapes": True, # 'max rank' will be part of the 2024.12 standard "max dimensions": 64, } def default_device(self): """ The default device used for new PyTorch arrays. 
See Also -------- __array_namespace_info__.capabilities, __array_namespace_info__.default_dtypes, __array_namespace_info__.dtypes, __array_namespace_info__.devices Returns ------- device : str The default device used for new PyTorch arrays. Examples -------- >>> info = np.__array_namespace_info__() >>> info.default_device() 'cpu' """ return torch.device("cpu") def default_dtypes(self, *, device=None): """ The default data types used for new PyTorch arrays. Parameters ---------- device : str, optional The device to get the default data types for. For PyTorch, only ``'cpu'`` is allowed. Returns ------- dtypes : dict A dictionary describing the default data types used for new PyTorch arrays. See Also -------- __array_namespace_info__.capabilities, __array_namespace_info__.default_device, __array_namespace_info__.dtypes, __array_namespace_info__.devices Examples -------- >>> info = np.__array_namespace_info__() >>> info.default_dtypes() {'real floating': torch.float32, 'complex floating': torch.complex64, 'integral': torch.int64, 'indexing': torch.int64} """ # Note: if the default is set to float64, the devices like MPS that # don't support float64 will error. We still return the default_dtype # value here because this error doesn't represent a different default # per-device. default_floating = torch.get_default_dtype() default_complex = torch.complex64 if default_floating == torch.float32 else torch.complex128 default_integral = torch.int64 return { "real floating": default_floating, "complex floating": default_complex, "integral": default_integral, "indexing": default_integral, } def _dtypes(self, kind): bool = torch.bool int8 = torch.int8 int16 = torch.int16 int32 = torch.int32 int64 = torch.int64 uint8 = torch.uint8 # uint16, uint32, and uint64 are present in newer versions of pytorch, # but they aren't generally supported by the array API functions, so # we omit them from this function. float32 = torch.float32 float64 = torch.float64 complex64 = torch.complex64 complex128 = torch.complex128 if kind is None: return { "bool": bool, "int8": int8, "int16": int16, "int32": int32, "int64": int64, "uint8": uint8, "float32": float32, "float64": float64, "complex64": complex64, "complex128": complex128, } if kind == "bool": return {"bool": bool} if kind == "signed integer": return { "int8": int8, "int16": int16, "int32": int32, "int64": int64, } if kind == "unsigned integer": return { "uint8": uint8, } if kind == "integral": return { "int8": int8, "int16": int16, "int32": int32, "int64": int64, "uint8": uint8, } if kind == "real floating": return { "float32": float32, "float64": float64, } if kind == "complex floating": return { "complex64": complex64, "complex128": complex128, } if kind == "numeric": return { "int8": int8, "int16": int16, "int32": int32, "int64": int64, "uint8": uint8, "float32": float32, "float64": float64, "complex64": complex64, "complex128": complex128, } if isinstance(kind, tuple): res = {} for k in kind: res.update(self.dtypes(kind=k)) return res raise ValueError(f"unsupported kind: {kind!r}") @cache def dtypes(self, *, device=None, kind=None): """ The array API data types supported by PyTorch. Note that this function only returns data types that are defined by the array API. Parameters ---------- device : str, optional The device to get the data types for. kind : str or tuple of str, optional The kind of data types to return. If ``None``, all data types are returned. If a string, only data types of that kind are returned. 
If a tuple, a dictionary containing the union of the given kinds is returned. The following kinds are supported: - ``'bool'``: boolean data types (i.e., ``bool``). - ``'signed integer'``: signed integer data types (i.e., ``int8``, ``int16``, ``int32``, ``int64``). - ``'unsigned integer'``: unsigned integer data types (i.e., ``uint8``, ``uint16``, ``uint32``, ``uint64``). - ``'integral'``: integer data types. Shorthand for ``('signed integer', 'unsigned integer')``. - ``'real floating'``: real-valued floating-point data types (i.e., ``float32``, ``float64``). - ``'complex floating'``: complex floating-point data types (i.e., ``complex64``, ``complex128``). - ``'numeric'``: numeric data types. Shorthand for ``('integral', 'real floating', 'complex floating')``. Returns ------- dtypes : dict A dictionary mapping the names of data types to the corresponding PyTorch data types. See Also -------- __array_namespace_info__.capabilities, __array_namespace_info__.default_device, __array_namespace_info__.default_dtypes, __array_namespace_info__.devices Examples -------- >>> info = np.__array_namespace_info__() >>> info.dtypes(kind='signed integer') {'int8': numpy.int8, 'int16': numpy.int16, 'int32': numpy.int32, 'int64': numpy.int64} """ res = self._dtypes(kind) for k, v in res.copy().items(): try: torch.empty((0,), dtype=v, device=device) except: del res[k] return res @cache def devices(self): """ The devices supported by PyTorch. Returns ------- devices : list of str The devices supported by PyTorch. See Also -------- __array_namespace_info__.capabilities, __array_namespace_info__.default_device, __array_namespace_info__.default_dtypes, __array_namespace_info__.dtypes Examples -------- >>> info = np.__array_namespace_info__() >>> info.devices() [device(type='cpu'), device(type='mps', index=0), device(type='meta')] """ # Torch doesn't have a straightforward way to get the list of all # currently supported devices. To do this, we first parse the error # message of torch.device to get the list of all possible types of # device: try: torch.device('notadevice') except RuntimeError as e: # The error message is something like: # "Expected one of cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone device type at start of device string: notadevice" devices_names = e.args[0].split('Expected one of ')[1].split(' device type')[0].split(', ') # Next we need to check for different indices for different devices. # device(device_name, index=index) doesn't actually check if the # device name or index is valid. We have to try to create a tensor # with it (which is why this function is cached). 
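# Probe each device type at increasing indices, keeping every index at which # a 0-sized tensor can actually be constructed; a given type stops at the # first failure or the first repeated device.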
devices = [] for device_name in devices_names: i = 0 while True: try: a = torch.empty((0,), device=torch.device(device_name, index=i)) if a.device in devices: break devices.append(a.device) except: break i += 1 return devices array-api-compat-1.11.2/array_api_compat/torch/fft.py000066400000000000000000000034021476700770300225360ustar00rootroot00000000000000from __future__ import annotations from typing import TYPE_CHECKING if TYPE_CHECKING: import torch array = torch.Tensor from typing import Union, Sequence, Literal from torch.fft import * # noqa: F403 import torch.fft # Several torch fft functions do not map axes to dim def fftn( x: array, /, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", **kwargs, ) -> array: return torch.fft.fftn(x, s=s, dim=axes, norm=norm, **kwargs) def ifftn( x: array, /, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", **kwargs, ) -> array: return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, **kwargs) def rfftn( x: array, /, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", **kwargs, ) -> array: return torch.fft.rfftn(x, s=s, dim=axes, norm=norm, **kwargs) def irfftn( x: array, /, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", **kwargs, ) -> array: return torch.fft.irfftn(x, s=s, dim=axes, norm=norm, **kwargs) def fftshift( x: array, /, *, axes: Union[int, Sequence[int]] = None, **kwargs, ) -> array: return torch.fft.fftshift(x, dim=axes, **kwargs) def ifftshift( x: array, /, *, axes: Union[int, Sequence[int]] = None, **kwargs, ) -> array: return torch.fft.ifftshift(x, dim=axes, **kwargs) __all__ = torch.fft.__all__ + [ "fftn", "ifftn", "rfftn", "irfftn", "fftshift", "ifftshift", ] _all_ignore = ['torch'] array-api-compat-1.11.2/array_api_compat/torch/linalg.py000066400000000000000000000112421476700770300232260ustar00rootroot00000000000000from __future__ import annotations from typing import TYPE_CHECKING if TYPE_CHECKING: import torch array = torch.Tensor from torch import dtype as Dtype from typing import Optional, Union, Tuple, Literal inf = float('inf') from ._aliases import _fix_promotion, sum from torch.linalg import * # noqa: F403 # torch.linalg doesn't define __all__ # from torch.linalg import __all__ as linalg_all from torch import linalg as torch_linalg linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')] # outer is implemented in torch but aren't in the linalg namespace from torch import outer # These functions are in both the main and linalg namespaces from ._aliases import matmul, matrix_transpose, tensordot # Note: torch.linalg.cross does not default to axis=-1 (it defaults to the # first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743 # torch.cross also does not support broadcasting when it would add new # dimensions https://github.com/pytorch/pytorch/issues/39656 def cross(x1: array, x2: array, /, *, axis: int = -1) -> array: x1, x2 = _fix_promotion(x1, x2, only_scalar=False) if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)): raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}") if not (x1.shape[axis] == x2.shape[axis] == 3): raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}") x1, x2 = torch.broadcast_tensors(x1, x2) return 
torch_linalg.cross(x1, x2, dim=axis) def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array: from ._aliases import isdtype x1, x2 = _fix_promotion(x1, x2, only_scalar=False) # torch.linalg.vecdot incorrectly allows broadcasting along the contracted dimension if x1.shape[axis] != x2.shape[axis]: raise ValueError("x1 and x2 must have the same size along the given axis") # torch.linalg.vecdot doesn't support integer dtypes if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'): if kwargs: raise RuntimeError("vecdot kwargs not supported for integral dtypes") x1_ = torch.moveaxis(x1, axis, -1) x2_ = torch.moveaxis(x2, axis, -1) x1_, x2_ = torch.broadcast_tensors(x1_, x2_) res = x1_[..., None, :] @ x2_[..., None] return res[..., 0, 0] return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs) def solve(x1: array, x2: array, /, **kwargs) -> array: x1, x2 = _fix_promotion(x1, x2, only_scalar=False) # Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve # whenever # 1. x1.ndim - 1 == x2.ndim # 2. x1.shape[:-1] == x2.shape # # See linalg_solve_is_vector_rhs in # aten/src/ATen/native/LinearAlgebraUtils.h and # TORCH_META_FUNC(_linalg_solve_ex) in # aten/src/ATen/native/BatchLinearAlgebra.cpp in the PyTorch source code. # # The easiest way to work around this is to prepend a size 1 dimension to # x2, since x2 is already one dimension less than x1. # # See https://github.com/pytorch/pytorch/issues/52915 if x2.ndim != 1 and x1.ndim - 1 == x2.ndim and x1.shape[:-1] == x2.shape: x2 = x2[None] return torch.linalg.solve(x1, x2, **kwargs) # torch.trace doesn't support the offset argument and doesn't support stacking def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array: # Use our wrapped sum to make sure it does upcasting correctly return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype) def vector_norm( x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Union[int, float, Literal[inf, -inf]] = 2, **kwargs, ) -> array: # torch.vector_norm incorrectly treats axis=() the same as axis=None if axis == (): out = kwargs.get('out') if out is None: dtype = None if x.dtype == torch.complex64: dtype = torch.float32 elif x.dtype == torch.complex128: dtype = torch.float64 out = torch.zeros_like(x, dtype=dtype) # The norm of a single scalar works out to abs(x) in every case except # for ord=0, which is x != 0. 
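# (The ord=0 "norm" counts nonzero entries, so for this elementwise fallback # it reduces to the boolean x != 0 cast to the output dtype.)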
if ord == 0: out[:] = (x != 0) else: out[:] = torch.abs(x) return out return torch.linalg.vector_norm(x, ord=ord, axis=axis, keepdim=keepdims, **kwargs) __all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot', 'cross', 'vecdot', 'solve', 'trace', 'vector_norm'] _all_ignore = ['torch_linalg', 'sum'] del linalg_all array-api-compat-1.11.2/cupy-xfails.txt000066400000000000000000000354311476700770300177700ustar00rootroot00000000000000# cupy doesn't have __index__ (and we cannot wrap the ndarray object) array_api_tests/test_array_object.py::test_scalar_casting[__index__(uint8)] array_api_tests/test_array_object.py::test_scalar_casting[__index__(uint16)] array_api_tests/test_array_object.py::test_scalar_casting[__index__(uint32)] array_api_tests/test_array_object.py::test_scalar_casting[__index__(uint64)] array_api_tests/test_array_object.py::test_scalar_casting[__index__(int8)] array_api_tests/test_array_object.py::test_scalar_casting[__index__(int16)] array_api_tests/test_array_object.py::test_scalar_casting[__index__(int32)] array_api_tests/test_array_object.py::test_scalar_casting[__index__(int64)] # testsuite bug (https://github.com/data-apis/array-api-tests/issues/172) array_api_tests/test_array_object.py::test_getitem # copy=False is not yet implemented array_api_tests/test_creation_functions.py::test_asarray_arrays # finfo test is testing that the result is a float instead of float32 (see # also https://github.com/data-apis/array-api/issues/405) array_api_tests/test_data_type_functions.py::test_finfo[float32] # Some array attributes are missing, and we do not wrap the array object array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] array_api_tests/test_has_names.py::test_has_names[array_method-__index__] array_api_tests/test_has_names.py::test_has_names[array_method-to_device] array_api_tests/test_has_names.py::test_has_names[array_attribute-mT] array_api_tests/test_linalg.py::test_solve # We cannot modify array methods array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x, s)] # floating point inaccuracy array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)] # cupy (arg)min/max wrong with infinities # https://github.com/cupy/cupy/issues/7424 array_api_tests/test_searching_functions.py::test_argmax array_api_tests/test_searching_functions.py::test_argmin array_api_tests/test_statistical_functions.py::test_min array_api_tests/test_statistical_functions.py::test_max # prod() sometimes doesn't give nan for 0*overflow array_api_tests/test_statistical_functions.py::test_prod # testsuite incorrectly thinks meshgrid doesn't have indexing argument # (https://github.com/data-apis/array-api-tests/issues/171) array_api_tests/test_signatures.py::test_func_signature[meshgrid] # We cannot add array attributes array_api_tests/test_signatures.py::test_array_method_signature[__array_namespace__] 
array_api_tests/test_signatures.py::test_array_method_signature[__index__] array_api_tests/test_signatures.py::test_array_method_signature[to_device] # We do not attempt to workaround special cases (and the operator method ones array_api_tests/test_special_cases.py::test_binary[__add__(isfinite(x1_i) and x1_i != 0 and x2_i == -x1_i) -> +0] array_api_tests/test_special_cases.py::test_binary[__add__(x1_i is +0 and x2_i is -0) -> +0] array_api_tests/test_special_cases.py::test_binary[__add__(x1_i is -0 and x2_i is +0) -> +0] array_api_tests/test_special_cases.py::test_binary[__add__(x1_i is -0 and x2_i is -0) -> -0] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i < 0 and x2_i is -0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i > 0 and x2_i is -0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -0 and x2_i < 0) -> +0] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -0 and x2_i > 0) -> -0] array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> x2_i] array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> x1_i] array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i > 0 and x2_i is +infinity) -> x1_i] array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> x2_i] array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i < 0 and x2_i is -0) -> NaN] array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i > 0 and x2_i is -0) -> NaN] array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is +0 and x2_i < 0) -> -0] array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is +0 and x2_i > 0) -> +0] array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is -0 and x2_i < 0) -> -0] array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is -0 and x2_i > 0) -> +0] array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i < 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i < 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -infinity] array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -0] array_api_tests/test_special_cases.py::test_binary[__pow__(x2_i is -0) -> 1] array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i < 0 and x2_i is -0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i > 0 and x2_i is -0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is -0 and x2_i < 0) -> +0] array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is -0 and x2_i > 0) -> -0] array_api_tests/test_special_cases.py::test_binary[add(isfinite(x1_i) and x1_i != 0 and x2_i == -x1_i) -> +0] array_api_tests/test_special_cases.py::test_binary[add(x1_i is +0 and x2_i is -0) -> +0] array_api_tests/test_special_cases.py::test_binary[add(x1_i is -0 and x2_i is +0) -> +0] array_api_tests/test_special_cases.py::test_binary[add(x1_i is -0 and x2_i is -0) -> -0] array_api_tests/test_special_cases.py::test_binary[atan2(x1_i < 0 and x2_i is -0) -> roughly -pi/2] 
array_api_tests/test_special_cases.py::test_binary[atan2(x1_i > 0 and x2_i is -0) -> roughly +pi/2] array_api_tests/test_special_cases.py::test_binary[atan2(x1_i is +0 and x2_i is -0) -> roughly +pi] array_api_tests/test_special_cases.py::test_binary[atan2(x1_i is -0 and x2_i < 0) -> roughly -pi] array_api_tests/test_special_cases.py::test_binary[atan2(x1_i is -0 and x2_i > 0) -> -0] array_api_tests/test_special_cases.py::test_binary[atan2(x1_i is -0 and x2_i is +0) -> -0] array_api_tests/test_special_cases.py::test_binary[atan2(x1_i is -0 and x2_i is -0) -> roughly -pi] array_api_tests/test_special_cases.py::test_binary[divide(x1_i < 0 and x2_i is -0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[divide(x1_i > 0 and x2_i is -0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[divide(x1_i is -0 and x2_i < 0) -> +0] array_api_tests/test_special_cases.py::test_binary[divide(x1_i is -0 and x2_i > 0) -> -0] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i < 0 and x2_i is -0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i > 0 and x2_i is -0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -0 and x2_i < 0) -> +0] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -0 and x2_i > 0) -> -0] array_api_tests/test_special_cases.py::test_binary[pow(x1_i is -0 and x2_i < 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] array_api_tests/test_special_cases.py::test_binary[pow(x1_i is -0 and x2_i < 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -infinity] array_api_tests/test_special_cases.py::test_binary[pow(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] array_api_tests/test_special_cases.py::test_binary[pow(x1_i is -0 and x2_i > 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -0] array_api_tests/test_special_cases.py::test_binary[pow(x2_i is -0) -> 1] array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> x2_i] array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> x1_i] array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i > 0 and x2_i is +infinity) -> x1_i] array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> x2_i] array_api_tests/test_special_cases.py::test_binary[remainder(x1_i < 0 and x2_i is -0) -> NaN] array_api_tests/test_special_cases.py::test_binary[remainder(x1_i > 0 and x2_i is -0) -> NaN] array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is +0 and x2_i < 0) -> -0] array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is +0 and x2_i > 0) -> +0] array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i < 0) -> -0] array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i > 0) -> +0] array_api_tests/test_special_cases.py::test_iop[__iadd__(isfinite(x1_i) and x1_i != 0 and x2_i == -x1_i) -> +0] array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is +0 and x2_i is -0) -> +0] array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is -0 and x2_i is +0) -> +0] array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is -0 and x2_i is -0) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i < 0 and x2_i is -0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i 
> 0 and x2_i is -0) -> -infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -0 and x2_i < 0) -> +0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -0 and x2_i > 0) -> -0] array_api_tests/test_special_cases.py::test_iop[__imod__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> x2_i] array_api_tests/test_special_cases.py::test_iop[__imod__(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> x1_i] array_api_tests/test_special_cases.py::test_iop[__imod__(isfinite(x1_i) and x1_i > 0 and x2_i is +infinity) -> x1_i] array_api_tests/test_special_cases.py::test_iop[__imod__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> x2_i] array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i < 0 and x2_i is -0) -> NaN] array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i > 0 and x2_i is -0) -> NaN] array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is +0 and x2_i < 0) -> -0] array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is +0 and x2_i > 0) -> +0] array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is -0 and x2_i < 0) -> -0] array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is -0 and x2_i > 0) -> +0] array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i < 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i < 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -infinity] array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -0] array_api_tests/test_special_cases.py::test_iop[__ipow__(x2_i is -0) -> 1] array_api_tests/test_special_cases.py::test_iop[__itruediv__(x1_i < 0 and x2_i is -0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__itruediv__(x1_i > 0 and x2_i is -0) -> -infinity] array_api_tests/test_special_cases.py::test_iop[__itruediv__(x1_i is -0 and x2_i < 0) -> +0] array_api_tests/test_special_cases.py::test_iop[__itruediv__(x1_i is -0 and x2_i > 0) -> -0] array_api_tests/test_special_cases.py::test_unary[__abs__(x_i is -0) -> +0] array_api_tests/test_special_cases.py::test_unary[abs(x_i is -0) -> +0] array_api_tests/test_special_cases.py::test_unary[asin(x_i is -0) -> -0] array_api_tests/test_special_cases.py::test_unary[asinh(x_i is -0) -> -0] array_api_tests/test_special_cases.py::test_unary[atan(x_i is -0) -> -0] array_api_tests/test_special_cases.py::test_unary[atanh(x_i is -0) -> -0] array_api_tests/test_special_cases.py::test_unary[ceil(x_i is -0) -> -0] array_api_tests/test_special_cases.py::test_unary[cos(x_i is -0) -> 1] array_api_tests/test_special_cases.py::test_unary[cosh(x_i is -0) -> 1] array_api_tests/test_special_cases.py::test_unary[exp(x_i is -0) -> 1] array_api_tests/test_special_cases.py::test_unary[expm1(x_i is -0) -> -0] array_api_tests/test_special_cases.py::test_unary[floor(x_i is -0) -> -0] array_api_tests/test_special_cases.py::test_unary[log1p(x_i is -0) -> -0] array_api_tests/test_special_cases.py::test_unary[round(x_i is -0) -> -0] array_api_tests/test_special_cases.py::test_unary[sin(x_i is -0) -> -0] array_api_tests/test_special_cases.py::test_unary[signbit(x_i is -0) -> True] array_api_tests/test_special_cases.py::test_unary[sinh(x_i is -0) -> -0] array_api_tests/test_special_cases.py::test_unary[sqrt(x_i is -0) -> -0] 
array_api_tests/test_special_cases.py::test_unary[tan(x_i is -0) -> -0] array_api_tests/test_special_cases.py::test_unary[tanh(x_i is -0) -> -0] array_api_tests/test_special_cases.py::test_unary[trunc(x_i is -0) -> -0] # CuPy gives the wrong shape for n-dim fft funcs. See # https://github.com/data-apis/array-api-compat/pull/78#issuecomment-1984527870 array_api_tests/test_fft.py::test_fftn array_api_tests/test_fft.py::test_ifftn array_api_tests/test_fft.py::test_rfftn # Observed in the 1.10 release process; likely related to the xfails above array_api_tests/test_fft.py::test_irfftn # 2023.12 support # cupy.ndarray cannot be specified as the `repeats` argument. array_api_tests/test_manipulation_functions.py::test_repeat array_api_tests/test_signatures.py::test_func_signature[from_dlpack] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] # 2024.12 support array_api_tests/test_signatures.py::test_func_signature[count_nonzero] array_api_tests/test_signatures.py::test_func_signature[bitwise_and] array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_or] array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] array_api_tests/test_special_cases.py::test_binary[nextafter(x1_i is +0 and x2_i is -0) -> -0] array-api-compat-1.11.2/dask-skips.txt000066400000000000000000000005141476700770300175670ustar00rootroot00000000000000# NOTE: dask tests run on a very small number of examples in CI due to # slowness. This causes very high flakiness in the tests. # Before changing this file, please run with at least 200 examples. # Passes, but extremely slow array_api_tests/test_linalg.py::test_outer # Hangs array_api_tests/test_creation_functions.py::test_eye array-api-compat-1.11.2/dask-xfails.txt000066400000000000000000000177751476700770300177310ustar00rootroot00000000000000# NOTE: dask tests run on a very small number of examples in CI due to # slowness. This causes very high flakiness in the tests. # Before changing this file, please run with at least 200 examples. # Broken edge case with shape 0 # https://github.com/dask/dask/issues/11800 array_api_tests/test_array_object.py::test_setitem # Various indexing errors array_api_tests/test_array_object.py::test_getitem_masking # ZeroDivisionError, and TypeError: tuple indices must be integers or slices, not tuple array_api_tests/test_creation_functions.py::test_eye # finfo(float32).eps returns float32 but should return float array_api_tests/test_data_type_functions.py::test_finfo[float32] # out[-1] is a dask.array but should be some floating-point number # (I think the test is not forcing the op to be computed?)
array_api_tests/test_creation_functions.py::test_linspace # Shape mismatch array_api_tests/test_indexing_functions.py::test_take # Array methods and attributes not already on da.Array cannot be wrapped array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] array_api_tests/test_has_names.py::test_has_names[array_method-to_device] array_api_tests/test_has_names.py::test_has_names[array_attribute-device] array_api_tests/test_has_names.py::test_has_names[array_attribute-mT] # Fails because shape is NaN since we don't materialize it yet array_api_tests/test_searching_functions.py::test_nonzero array_api_tests/test_set_functions.py::test_unique_all array_api_tests/test_set_functions.py::test_unique_counts # Different error but same cause as above; we're just trying to do ndindex on a nan shape array_api_tests/test_set_functions.py::test_unique_inverse array_api_tests/test_set_functions.py::test_unique_values # Linalg failures (signature failures/missing methods) # fails for ndim > 2 array_api_tests/test_linalg.py::test_svdvals # dtype mismatch: got uint64, but should be uint8; NPY_PROMOTION_STATE=weak doesn't help array_api_tests/test_linalg.py::test_tensordot # AssertionError: out.dtype=uint64, but should be uint8 [tensordot(uint8, uint8)] array_api_tests/test_linalg.py::test_linalg_tensordot # ZeroDivisionError in dask's normalize_chunks/auto_chunks internals array_api_tests/test_linalg.py::test_inv array_api_tests/test_linalg.py::test_matrix_power # Linalg - these don't exist in dask array_api_tests/test_signatures.py::test_extension_func_signature[linalg.cross] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.det] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.eigh] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.eigvalsh] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.matrix_power] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.pinv] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.slogdet] array_api_tests/test_linalg.py::test_cross array_api_tests/test_linalg.py::test_det array_api_tests/test_linalg.py::test_eigh array_api_tests/test_linalg.py::test_eigvalsh array_api_tests/test_linalg.py::test_matrix_rank array_api_tests/test_linalg.py::test_pinv array_api_tests/test_linalg.py::test_slogdet array_api_tests/test_has_names.py::test_has_names[linalg-cross] array_api_tests/test_has_names.py::test_has_names[linalg-det] array_api_tests/test_has_names.py::test_has_names[linalg-eigh] array_api_tests/test_has_names.py::test_has_names[linalg-eigvalsh] array_api_tests/test_has_names.py::test_has_names[linalg-matrix_power] array_api_tests/test_has_names.py::test_has_names[linalg-pinv] array_api_tests/test_has_names.py::test_has_names[linalg-slogdet] # Constructing the input arrays fails due to a weird shape error...
array_api_tests/test_linalg.py::test_solve # missing full_matrices kw # https://github.com/dask/dask/issues/10389 # also only supports 2-d inputs array_api_tests/test_linalg.py::test_svd # Missing dlpack stuff array_api_tests/test_signatures.py::test_func_signature[from_dlpack] array_api_tests/test_signatures.py::test_array_method_signature[__array_namespace__] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack_device__] array_api_tests/test_signatures.py::test_array_method_signature[to_device] array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack] array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack__] array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack_device__] # No mT on dask array array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices # Edge case of args near 2**63 # https://github.com/dask/dask/issues/11706 array_api_tests/test_creation_functions.py::test_arange # da.searchsorted with a sorter argument is not supported array_api_tests/test_searching_functions.py::test_searchsorted # 2023.12 support array_api_tests/test_manipulation_functions.py::test_repeat # 2024.12 support array_api_tests/test_array_object.py::test_getitem_arrays_and_ints_1[1] array_api_tests/test_array_object.py::test_getitem_arrays_and_ints_1[None] array_api_tests/test_array_object.py::test_getitem_arrays_and_ints_2[1] array_api_tests/test_array_object.py::test_getitem_arrays_and_ints_2[None] array_api_tests/test_has_names.py::test_has_names[indexing-take_along_axis] array_api_tests/test_signatures.py::test_func_signature[count_nonzero] array_api_tests/test_signatures.py::test_func_signature[take_along_axis] array_api_tests/test_linalg.py::test_cholesky array_api_tests/test_linalg.py::test_linalg_matmul array_api_tests/test_linalg.py::test_matmul array_api_tests/test_linalg.py::test_matrix_norm array_api_tests/test_linalg.py::test_qr array_api_tests/test_manipulation_functions.py::test_roll # Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.) 
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] array-api-compat-1.11.2/docs/000077500000000000000000000000001476700770300157055ustar00rootroot00000000000000array-api-compat-1.11.2/docs/Makefile000066400000000000000000000013611476700770300173460ustar00rootroot00000000000000# Minimal makefile for Sphinx documentation # # You can set these variables from the command line, and also # from the environment for the first two. SPHINXOPTS ?= SPHINXBUILD ?= sphinx-build SOURCEDIR = . BUILDDIR = _build # Put it first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) livehtml: sphinx-autobuild --open-browser --watch .. 
--port 0 -b html $(SOURCEDIR) $(ALLSPHINXOPTS) $(BUILDDIR)/html array-api-compat-1.11.2/docs/_static/000077500000000000000000000000001476700770300173335ustar00rootroot00000000000000array-api-compat-1.11.2/docs/_static/custom.css000066400000000000000000000013071476700770300213600ustar00rootroot00000000000000/* Makes the text look better on Mac retina displays (the Furo CSS disables subpixel antialiasing). */ body { -webkit-font-smoothing: auto; -moz-osx-font-smoothing: auto; } /* Disable the fancy scrolling behavior when jumping to headers (this is too slow for long pages) */ html { scroll-behavior: auto; } /* Make checkboxes from the tasklist extension ('- [ ]' in Markdown) not add bullet points to the checkboxes. This can be removed once https://github.com/executablebooks/mdit-py-plugins/issues/59 is addressed. */ .contains-task-list { list-style: none; } /* Make the checkboxes indented like they are bullets */ .task-list-item-checkbox { margin: 0 0.2em 0.25em -1.4em; } array-api-compat-1.11.2/docs/_static/favicon.png000066400000000000000000000120401476700770300214630ustar00rootroot00000000000000[binary PNG image data omitted]
array-api-compat-1.11.2/docs/changelog.md000066400000000000000000000316641476700770300201640ustar00rootroot00000000000000# Changelog ## 1.11.2 (2025-03-20) This is a bugfix release with no new features compared to version 1.11. - fix the `result_type` wrapper for pytorch. Previously, `result_type` had multiple issues with scalar arguments. - fix several issues with `clip` wrappers. Previously, `clip` was failing to allow behaviors which are unspecified by the 2024.12 standard but allowed by the array libraries. The following users contributed to this release: Evgeni Burovski Guido Imperiale Magnus Dalen Kvalevåg ## 1.11.1 (2025-03-04) This is a bugfix release with no new features compared to version 1.11. ### Major Changes - fix `count_nonzero` wrappers: work around the lack of the `keepdims` argument in several array libraries (torch, dask, cupy); work around numpy returning python ints for some input combinations. ### Minor Changes - running self-tests does not require all array libraries. Missing libraries are skipped. The following users contributed to this release: Evgeni Burovski Guido Imperiale ## 1.11.0 (2025-02-27) ### Major Changes This release targets the 2024.12 Array API revision. This includes - `__array_api_version__` for the wrapped APIs is now set to `2024.12`; - Wrappers for `count_nonzero`; - Wrappers for `cumulative_prod`; - Wrappers for `take_along_axis` (with the exception of Dask); - Wrappers for `diff`; - `__capabilities__` dict contains a `max_dimensions` key; - Python scalars are accepted as arguments to `result_type`; - `fft.fftfreq` and `fft.rfftfreq` functions now accept an optional `dtype` argument to control the output data type. Note that these wrappers, as well as other 2024.12 features, are relatively undertested in this release, and may have rough edges. Please report any issues you encounter in [the issue tracker](https://github.com/data-apis/array-api-compat/issues). New functions to test properties of arrays: - `is_writeable_array` (benefits NumPy, JAX, Sparse) - `is_lazy_array` (benefits JAX, Dask, ndonnx) Improved support for JAX: - Workarounds for the `.device` attribute and the `to_device` function not working correctly within `jax.jit` ### Minor Changes - Several improvements to `dask.array` wrappers: - `size` returns None for arrays of unknown shapes. - `astype(..., copy=True)` always copies, independently of the Dask version. - implementations of `sort` and `argsort` are now available. Note that these implementations are relatively crude, and might be memory intensive. - `asarray` no longer accidentally materializes the Dask graph - `torch` wrappers contain unsigned integer dtypes of widths >8 bits, `uint16`, `uint32` and `uint64`, if the PyTorch version is at least 2.3. Note that the unsigned integer support is incomplete in PyTorch itself, see [gh-253](https://github.com/data-apis/array-api-compat/pull/253).
### Authors The following users contributed to this release: Athan Reines Evgeni Burovski Guido Imperiale Lucas Colley Ralf Gommers Thomas Li ## 1.10.0 (2024-12-25) ### Major Changes - New function `is_writeable_array` adds transparent support for readonly arrays, such as JAX arrays or numpy arrays with `.flags.writeable=False`. - `asarray(..., copy=None)` with `dask` backend always copies, so that `copy=None` and `copy=True` are equivalent for the `dask` backend. This change is made to be forward compatible with the `dask==2024.12` release. ### Minor Changes - `array_namespace` accepts (and ignores) `None` and python scalars (int, float, complex, bool). This change is to simplify downstream adoption, for functions where arguments can be either arrays or scalars. - `vecdot` conjugates its first argument, as stipulated by the Array API spec. Previously, conjugation of the first argument was missing. ## 1.9.1 (2024-10-29) ### Major Changes - `__array_api_version__` for the wrapped APIs is now set to `2023.12`. ### Minor Changes - Wrap `sign` so that it always uses the standard definition for complex numbers, and always propagates nans. - Wrap dask.array.fft. - Readd `python_requires` to the package metadata. ## 1.9 (2024-10-??) ### Major Changes - New helper functions to determine if a namespace is from a given library ({func}`~.is_numpy_namespace`, {func}`~.is_torch_namespace`, etc.). - More support for the [2023.12 version of the standard](https://data-apis.org/array-api/latest/changelog.html#v2023-12). This includes - Wrappers for `cumulative_sum()`. - Wrappers for `unstack()`. - Update floating-point type promotion in `sum()`, `prod()`, and `trace()` to be in line with the 2023.12 specification (32-bit types no longer promote to 64-bit when `dtype=None`). - Add the [inspection APIs](https://data-apis.org/array-api/latest/API_specification/inspection.html) to the wrapped namespaces. These can be accessed with `xp.__array_namespace_info__()`. - Various fixes to the `clip()` wrappers. - `torch.conj` now wraps `torch.conj_physical`, which makes a copy rather than setting the conjugation bit, as arrays with the conjugation bit set do not support some APIs. - `torch.sign` is now wrapped to support complex numbers and propagate nans properly. ### Minor Changes - NumPy 2.0 is now wrapped again. Previously it was unwrapped because it has full 2022.12 array API support, but it now requires wrapping again for 2023.12 support. - Support for JAX 0.4.32 and newer, which implements the array API directly in `jax.numpy`. - `hypot`, `minimum`, and `maximum` (new in 2023.12) are wrapped in PyTorch to support proper scalar type promotion. ## 1.8 (2024-07-24) ### Major Changes - Add support for [ndonnx](https://github.com/Quantco/ndonnx). Array API support itself lives in the ndonnx library, but this adds the {func}`~.is_ndonnx_array` helper function. ([@adityagoel4512](https://github.com/adityagoel4512)). - Partial support for the [2023.12 version of the standard](https://data-apis.org/array-api/latest/changelog.html#v2023-12). This includes - Wrappers for `clip()`. - torch wrapper for `copysign()` with correct type promotion. Note that many of the new functions in the 2023.12 version of the standard are already fully implemented in upstream libraries and will already work. ## 1.7.1 (2024-05-28) ### Minor Changes - Fix a typo in setup.py ([@sunpoet](https://github.com/sunpoet)). ## 1.7 (2024-05-24) ### Major Changes - Add support for `sparse`.
Note that unlike other array libraries, array-api-compat does not contain any wrappers for `sparse` functions. All `sparse` array API support is in `sparse` itself. Thus, there is no `array_api_compat.sparse` submodule, and `array_namespace()` returns the `sparse` module. - Added the function `is_pydata_sparse_array(x)`. ### Minor Changes - Fix JAX `float0` arrays. See https://github.com/google/jax/issues/20620. ([@NeilGirdhar](https://github.com/NeilGirdhar)) - Fix `torch.linalg.vector_norm()` when `axis=()`. - Fix `torch.linalg.solve()` to apply the array API standard rules for when `x2` should be treated as a vector vs. a matrix. - Fix PyTorch test failures on CI by skipping uint16, uint32, uint64 tests. ## 1.6 (2024-03-29) ### Major Changes - Drop support for Python 3.8. - NumPy 2.0 is now left completely unwrapped. - New flag `use_compat` to {func}`~.array_namespace` to force the use or non-use of the compat wrapper namespace. The default is to return a compat namespace when it is appropriate. - Fix the `copy` flag to `asarray` for NumPy, CuPy, and Dask. - Fix the `device` flag to `asarray` for CuPy. - Fix various issues with `asarray` for Dask. ### Minor Changes - Test Python 3.12 on CI. - Add more tests for {func}`~.array_namespace`. - Add more tests for `asarray`. - Add a test that there are no hard dependencies. ## 1.5.1 (2024-03-20) ### Minor Changes - Add [HTML documentation](https://data-apis.org/array-api-compat/). Includes new documentation on the [scope of the package](scope) and new [developer documentation](dev/index.md). - Fix `array_api_compat.numpy.asarray(torch.Tensor)` to return a NumPy array. - Allow Python scalars in torch functions. - Fix the `torch.std` wrapper when `correction` is an `int`. - Fix issues with `qr` and `svd` in the Dask wrappers. ## 1.5 (2024-03-07) ### Major Changes - Add support for Dask ([@lithomas1](https://github.com/lithomas1)). - Add support for JAX. Note that unlike other array libraries, array-api-compat does not contain any wrappers for JAX functions. All JAX array API support is in JAX itself. Thus, there is no `array_api_compat.jax` submodule, and `array_namespace()` returns the `jax.experimental.array_api` module. - The functions `is_numpy_array(x)`, `is_cupy_array(x)`, `is_torch_array(x)`, `is_dask_array(x)`, `is_jax_array(x)` are now part of the public `array_api_compat` API. - Add wrappers for the `fft` extension module for NumPy, CuPy, and PyTorch. ### Minor Changes - Allow `'2022.12'` as the `api_version` in {func}`~.array_namespace()`. `'2021.12'` is also supported but will issue a warning since the returned namespace will still be a 2022.12 compliant one. - Add wrapper for numpy.linalg.solve, which broadcasts the inputs according to the standard. - Add wrappers for various PyTorch linalg functions. - Fix a bug with `numpy.linalg.vector_norm(keepdims=True)`. - BREAKING: Update `vecdot` wrappers to apply `axes` before broadcasting, not after. This matches the updated 2023.12 standard wording, and also the behavior of the new `numpy.vecdot` gufunc in NumPy 2.0. - Fix some linalg functions which were supposed to be in both the main namespace and the linalg extension namespace. - Add Ruff to CI. ([@adonath](https://github.com/adonath)) - Test that internal definitions of `__all__` are self-consistent, which should help to avoid issues where wrappers are accidentally not exported to the compat namespaces properly. ## 1.4.1 (2024-01-18) ### Minor Changes - Add support for the upcoming NumPy 2.0 release.
- Added a torch wrapper for `trace` (`torch.trace` doesn't support the `offset` argument or stacking) - Wrap numpy, cupy, and torch `nonzero` to raise an error for zero-dimensional input arrays. - Add torch wrapper for `newaxis`. - Improve error message for `array_namespace` - Fix linalg.cholesky returning the conjugate of the expected upper decomposition for numpy and cupy. ## 1.4 (2023-09-13) ### Major Changes - Releases are now made with GitHub Actions (thanks [@matthewfeickert](https://github.com/matthewfeickert)). ### Minor Changes - Fix `torch.result_type()` cross-kind promotion ([@lucascolley](https://github.com/lucascolley)). - Fix the torch.take() wrapper to make axis optional for ndim = 1. - Add requires-python metadata to the package ([@matthewfeickert](https://github.com/matthewfeickert)). ## 1.3 (2023-06-20) ### Major Changes - Add [2022.12](https://data-apis.org/array-api/2022.12/) standard support. This includes things like adding complex dtype support, adding the new `take` function, and various minor changes in the specification. ### Minor Changes - Support `"cpu"` in CuPy `to_device()`. - Return a new array in NumPy/CuPy `reshape(copy=False)`. - Fix signatures for PyTorch `broadcast_to` and `permute_dims`. ## 1.2 (2023-04-03) ### Major Changes - Support the linalg extension in the `array_api_compat.torch` namespace. - Add `isdtype()`. ### Minor Changes - Fix the `k` keyword argument to `tril` and `triu` in `torch`. ## 1.1.1 (2023-03-10) ### Major Changes - Rename `get_namespace()` to `array_namespace()` (`get_namespace()` is maintained as a backwards compatible alias). ### Minor Changes - The minimum supported NumPy version is now 1.21. Fixed a few issues with NumPy 1.21 (with `unique_*` and `asarray`), although there are also a few known issues with this version (see the README). - Add `api_version` to `get_namespace()`. - `array_namespace()` (*née* `get_namespace()`) now works correctly with `torch` tensors. - `array_namespace()` (*née* `get_namespace()`) now works correctly with `numpy.array_api` arrays. - `array_namespace()` (*née* `get_namespace()`) now raises `TypeError` instead of `ValueError`. - Fix the `torch.std` wrapper. - Add `torch` wrappers for `ones`, `empty`, and `zeros` so that `shape` can be passed as a keyword argument. ## 1.1 (2023-02-24) ### Major Changes - Added support for PyTorch. - Add helper function `size()` (required if torch is used as `torch.Tensor.size` is a method that is incompatible with the array API [`.size`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.size.html#array_api.array.size)). - All wrapper functions that wrap existing library functions now pass through arbitrary `**kwargs`. ### Minor Changes - Added CI to run against the [array API testsuite](https://github.com/data-apis/array-api-tests). - Fix `sort(stable=False)` and `argsort(stable=False)` with CuPy. ## 1.0 (2022-12-05) ### Major Changes - Initial release. Includes support for NumPy and CuPy. array-api-compat-1.11.2/docs/conf.py000066400000000000000000000061201476700770300172030ustar00rootroot00000000000000# Configuration file for the Sphinx documentation builder. 
# # For the full list of built-in configuration values, see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information import sys import os sys.path.insert(0, os.path.abspath('..')) project = 'array-api-compat' copyright = '2024, Consortium for Python Data API Standards' author = 'Consortium for Python Data API Standards' import array_api_compat release = array_api_compat.__version__ # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration extensions = [ 'myst_parser', 'sphinx.ext.autodoc', 'sphinx.ext.napoleon', 'sphinx.ext.intersphinx', 'sphinx_copybutton', ] intersphinx_mapping = { 'cupy': ('https://docs.cupy.dev/en/stable', None), 'torch': ('https://pytorch.org/docs/stable/', None), } # Require :external: to reference intersphinx. intersphinx_disabled_reftypes = ['*'] templates_path = ['_templates'] exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] myst_enable_extensions = ["dollarmath", "linkify", "tasklist"] myst_enable_checkboxes = True napoleon_use_rtype = False napoleon_use_param = False # Make sphinx give errors for bad cross-references nitpicky = True # autodoc wants to make cross-references for every type hint. But a lot of # them don't actually refer to anything that we have a document for. nitpick_ignore = [ ("py:class", "Array"), ("py:class", "Device"), ] # Lets us use single backticks for code in RST default_role = 'code' # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output html_theme = 'furo' html_static_path = ['_static'] html_css_files = ['custom.css'] html_theme_options = { # See https://pradyunsg.me/furo/customisation/footer/ "footer_icons": [ { "name": "GitHub", "url": "https://github.com/data-apis/array-api-compat", "html": """ """, "class": "", }, ], } # Logo html_favicon = "_static/favicon.png" # html_logo = "_static/logo.svg" array-api-compat-1.11.2/docs/dev/000077500000000000000000000000001476700770300164635ustar00rootroot00000000000000array-api-compat-1.11.2/docs/dev/implementation-notes.md000066400000000000000000000037221476700770300231640ustar00rootroot00000000000000# Implementation Notes Since NumPy, CuPy, and to a degree, Dask, are nearly identical in behavior, most wrapping logic can be shared between them. Wrapped functions that have the same logic between multiple libraries are in `array_api_compat/common/`. These functions are defined like ```py # In array_api_compat/common/_aliases.py def acos(x, /, xp): return xp.arccos(x) ``` The `xp` argument refers to the original array namespace (e.g., `numpy` or `cupy`). Then in the specific `array_api_compat/numpy/` and `array_api_compat/cupy/` namespaces, the `@get_xp` decorator is applied to these functions, which automatically removes the `xp` argument from the function signature and replaces it with the corresponding array library, like ```py # In array_api_compat/numpy/_aliases.py from ..common import _aliases import numpy as np acos = get_xp(np)(_aliases.acos) ``` This `acos` now has the signature `acos(x, /)` and calls `numpy.arccos`. 
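Conceptually, `get_xp` is closure-based partial application of the `xp` argument. The following is a purely illustrative sketch, not the actual `array_api_compat` implementation (which, as described above, also removes `xp` from the visible function signature; this sketch does not): ```py # Purely illustrative sketch of a get_xp-style decorator. from functools import wraps def get_xp(xp): """Return a decorator that binds ``xp`` to a fixed array namespace.""" def decorator(func): @wraps(func) def wrapper(*args, **kwargs): # Forward the captured namespace as the ``xp`` argument. return func(*args, xp=xp, **kwargs) return wrapper return decorator ```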
Similarly, for CuPy: ```py # In array_api_compat/cupy/_aliases.py from ..common import _aliases import cupy as cp acos = get_xp(cp)(_aliases.acos) ``` Most NumPy and CuPy functions are defined in this way, since their behaviors are nearly identical. PyTorch uses a similar layout in `array_api_compat/torch/`, but it differs enough from NumPy/CuPy that very few common wrappers for those libraries are reused. Dask is close to NumPy in behavior and so most Dask functions also reuse the NumPy/CuPy common wrappers. Occasionally, a wrapper implementation will need to reference another wrapper implementation, rather than the base `xp` version. The easiest way to do this is to call `array_namespace`, like ```py wrapped_xp = array_namespace(x) wrapped_xp.wrapped_func(...) ``` Also, if there is a very minor difference required for wrapping, say, CuPy and NumPy, they can still use a common implementation in `common/_aliases.py` and use the `is_*_namespace()` or `is_*_function()` [helper functions](../helper-functions.rst) to branch as necessary. array-api-compat-1.11.2/docs/dev/index.md000066400000000000000000000004211476700770300201110ustar00rootroot00000000000000# Development Notes This is internal documentation related to the development of array-api-compat. It is recommended that contributors read through this documentation. ```{toctree} :titlesonly: special-considerations.md implementation-notes.md tests.md releasing.md ``` array-api-compat-1.11.2/docs/dev/releasing.md000066400000000000000000000061251476700770300207620ustar00rootroot00000000000000# Releasing - [ ] **Create a PR with a release branch.** This makes it easy to verify that CI is passing, and also gives you a place to push up updates to the changelog and any last minute fixes for the release. - [ ] **Double check the release branch is fully merged with `main`.** (e.g., if the release branch is called `release`) ``` git checkout main git pull git checkout release git merge main ``` - [ ] **Make sure that all CI tests are passing.** Note that the GitHub action that publishes to PyPI does not check if CI is passing before publishing. So you need to check this manually. This does mean you can ignore CI failures, but ideally you should fix any failures or update the `*-xfails.txt` files before tagging, so that CI and the CuPy tests fully pass. Otherwise it will be hard to tell what things are breaking in the future. It's also a good idea to remove any xpasses from those files (but be aware that some xfails are from flaky failures, so unless you know the underlying issue has been fixed, an xpass test should probably stay xfailed). - [ ] **Test CuPy.** CuPy must be tested manually (it isn't tested on CI, see https://github.com/data-apis/array-api-compat/issues/197). Use the script ``` ./test_cupy.sh ``` on a machine with a CUDA GPU. - [ ] **Update the version.** You must edit ``` array_api_compat/__init__.py ``` and update the version (the version is not computed from the tag because that would break vendorability). - [ ] **Update the [changelog](../changelog.md).** Edit ``` docs/changelog.md ``` with the changes for the release. - [ ] **Create the release tag.** Once everything is ready, create a tag ``` git tag -a <tag> ``` (note the tag names are not prefixed; for instance, the tag for version 1.5 is just `1.5`) - [ ] **Push the tag to GitHub.** *This is the final step.
Doing this will build and publish the release!* ``` git push origin <tag> ``` This will trigger the [`publish distributions`](https://github.com/data-apis/array-api-compat/actions/workflows/publish-package.yml) GitHub Action that will build the release and push it to PyPI. - [ ] **Check that the [`publish distributions`](https://github.com/data-apis/array-api-compat/actions/workflows/publish-package.yml) build on the tag worked.** Note that this action will run even if the other CI fails, so you must make sure that CI is passing *before* tagging. If it failed for some reason, you may need to delete the tag and try again. - [ ] **Merge the release branch.** This way any changes you made in the branch, such as updates to the changelog or xfails files, are updated in `main`. This will also make the docs update (the docs are published automatically from the sources on `main`). - [ ] **Update conda-forge.** After the PyPI package is published, the conda-forge bot should update the feedstock automatically after some time. The bot should automerge, so in most cases you don't need to do anything here, unless some metadata on the feedstock needs to be updated. array-api-compat-1.11.2/docs/dev/special-considerations.md000066400000000000000000000107431476700770300234540ustar00rootroot00000000000000# Special Considerations array-api-compat requires some special development considerations that are different from most other Python libraries. The goal of array-api-compat is to be a small library that packages can either vendor or add as a dependency to implement array API support. Consequently, certain design considerations should be taken into account: (no-dependencies)= - **No Hard Dependencies.** Although array-api-compat "depends" on NumPy, CuPy, PyTorch, etc., it does not hard depend on them. These libraries are not imported unless either an array object is passed to {func}`~.array_namespace()`, or the specific `array_api_compat.<library>` sub-namespace is explicitly imported. This is tested (as best as possible) in `tests/test_no_dependencies.py`. - **Vendorability.** array-api-compat should be [vendorable](vendoring). This means that, for instance, all imports in the library are relative imports. No code in the package specifically references the name `array_api_compat` (we also support renaming the package to something else). Vendorability support is tested in `tests/test_vendoring.py`. - **Pure Python.** To make array-api-compat as easy as possible to add as a dependency, the code is all pure Python. - **Minimal Wrapping Only.** The wrapping functionality is minimal. This means that if something is difficult to wrap using pure Python, or if trying to support some array API behavior would require a significant amount of code, we prefer to leave the behavior as an upstream issue for the array library, and [document it as a known difference](../supported-array-libraries.md). This also means that we do not at this point in time implement anything other than wrappers for functions in the standard, and basic [helper functions](../helper-functions.rst) that would be useful for most users of array-api-compat. The addition of functions that are not part of the array API standard is currently out-of-scope for this package (see the [Scope](scope) section of the documentation). - **No Side-Effects.** array-api-compat behavior should be localized to only the specific code that imports and uses it. It should be invisible to end-users or users of dependent codes. This in particular implies the next two points.
- **No Monkey Patching.** `array-api-compat` should not attempt to modify anything about the underlying library. It is a *wrapper* library only. - **No Modifying the Array Object.** The array (or tensor) object of the array library cannot be modified. This also precludes the creation of array subclasses or wrapper classes. Any non-standard behavior that is built in to the array object, such as the behavior of [array methods](https://data-apis.org/array-api/latest/API_specification/array_object.html), is therefore left unwrapped. Users can work around issues by using corresponding [elementwise functions](https://data-apis.org/array-api/latest/API_specification/elementwise_functions.html) instead of [operators](https://data-apis.org/array-api/latest/API_specification/array_object.html#operators), and by using the [helper functions](../helper-functions.rst) provided by array-api-compat instead of attributes or methods like `x.to_device()`. - **Avoid Restricting Behavior that is Outside the Scope of the Standard.** All array libraries have functions and behaviors that are outside of the scope of what is specified by the standard. These behaviors should be left intact whenever possible, unless the standard explicitly disallows something. This means - All namespaces are *extended* with wrapper functions. You may notice the extensive use of `import *` in various files in `array_api_compat`. While this would normally be questionable, this is the [one actual legitimate use-case for `import *`](https://peps.python.org/pep-0008/#imports), to re-export names from an external namespace. - All wrapper functions pass `**kwargs` through to the wrapped function. - Input types not supported by the standard should work if they work in the underlying wrapped function (for instance, Python scalars or `np.ndarray` subclasses). By keeping underlying behaviors intact, it is easier for libraries to swap out NumPy or other array libraries for array-api-compat, and it is easier for libraries to write array library-specific code paths. The onus is on users of array-api-compat to ensure their array API code is portable, e.g., by testing against [array-api-strict](array-api-strict). array-api-compat-1.11.2/docs/dev/tests.md000066400000000000000000000034531476700770300201540ustar00rootroot00000000000000# Tests The majority of the behavior for array-api-compat is tested by the [array-api-tests](https://github.com/data-apis/array-api-tests) test suite for the array API standard. There are also array-api-compat specific tests in [`tests/`](https://github.com/data-apis/array-api-compat/tree/main/tests). These tests should be limited to things that are not tested by the test suite, e.g., tests for [helper functions](../helper-functions.rst) or for behavior that is not strictly required by the standard. To run these tests, install the dependencies from `requirements-dev.txt` (array-api-compat has [no hard runtime dependencies](no-dependencies)). array-api-tests is run on CI against all supported libraries ([except for JAX](jax-support) and [Sparse](sparse-support)). This is achieved by a [reusable GitHub Actions Workflow](https://github.com/data-apis/array-api-compat/blob/main/.github/workflows/array-api-tests.yml). Most libraries have tests that must be xfailed or skipped for various reasons. These are defined in specific `<library>-xfails.txt` files and are automatically forwarded to array-api-tests.
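For a rough idea of what a local run looks like, here is a sketch. It assumes array-api-compat and array-api-tests are checked out side by side, that `ARRAY_API_TESTS_MODULE` selects the namespace under test, and that the `--xfails-file`/`--skips-file` pytest options are provided by the array-api-tests suite (check its README if these have changed): ``` export ARRAY_API_TESTS_MODULE=array_api_compat.numpy cd array-api-tests pytest array_api_tests/ --max-examples 200 \ --xfails-file ../array-api-compat/numpy-xfails.txt \ --skips-file ../array-api-compat/numpy-skips.txt ```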
You may often need to update these xfail files, either to add new xfails (e.g., because of new test suite features, or because a test that was previously thought to be passing is actually flaky). Try to keep the xfails files organized, with comments pointing to upstream issues whenever possible. From time to time, xpass tests should be removed from the xfail files, but be aware that many xfail tests are flaky, so an xpass should only be removed if you know that the underlying issue has been fixed. Array libraries that require a GPU to run (currently only CuPy) cannot be tested on CI. There is a helper script `test_cupy.sh` that can be used to manually test CuPy on a machine with a CUDA GPU. array-api-compat-1.11.2/docs/helper-functions.rst000066400000000000000000000041261476700770300217270ustar00rootroot00000000000000Helper Functions ================ .. currentmodule:: array_api_compat In addition to the wrapped library namespaces and functions in the array API specification, there are several helper functions included here that aren't part of the specification but which are useful for using the array API: Entry-point Helpers ------------------- The `array_namespace()` function is the primary entry-point for array API consuming libraries. .. autofunction:: array_namespace .. autofunction:: is_array_api_obj Array Method Helpers -------------------- array-api-compat does not attempt to wrap or monkey patch the array object for any library. Consequently, any API differences for the `array object <https://data-apis.org/array-api/latest/API_specification/array_object.html>`__ cannot be directly wrapped. Some libraries do not define some of these methods or define them differently. For these, helper functions are provided which can be used instead. Note that if you have a compatibility issue with an operator method (like `__add__`, i.e., `+`) you may prefer to use the corresponding `elementwise function <https://data-apis.org/array-api/latest/API_specification/elementwise_functions.html>`__ instead, which would be wrapped. .. autofunction:: device .. autofunction:: to_device .. autofunction:: size Inspection Helpers ------------------ These convenience functions can be used to test if an array or namespace comes from a specific library without importing that library if it hasn't been imported yet. .. autofunction:: is_numpy_array .. autofunction:: is_cupy_array .. autofunction:: is_torch_array .. autofunction:: is_dask_array .. autofunction:: is_jax_array .. autofunction:: is_pydata_sparse_array .. autofunction:: is_ndonnx_array .. autofunction:: is_writeable_array .. autofunction:: is_lazy_array .. autofunction:: is_numpy_namespace .. autofunction:: is_cupy_namespace .. autofunction:: is_torch_namespace .. autofunction:: is_dask_namespace .. autofunction:: is_jax_namespace .. autofunction:: is_pydata_sparse_namespace .. autofunction:: is_ndonnx_namespace .. autofunction:: is_array_api_strict_namespace array-api-compat-1.11.2/docs/index.md000066400000000000000000000171221476700770300173410ustar00rootroot00000000000000# Array API compatibility library This is a small wrapper around common array libraries that is compatible with the [Array API standard](https://data-apis.org/array-api/latest/). Currently, NumPy, CuPy, PyTorch, Dask, JAX, ndonnx, and Sparse are supported. If you want support for other array libraries, or if you encounter any issues, please [open an issue](https://github.com/data-apis/array-api-compat/issues). Note that some of the functionality in this library is backwards incompatible with the corresponding wrapped libraries.
The end-goal is to eventually make each array library itself fully compatible with the array API, but this requires making backwards incompatible changes in many cases, so this will take some time. Currently all libraries here are implemented against the [2024.12 version](https://data-apis.org/array-api/2024.12/) of the standard. ## Installation `array-api-compat` is available on both [PyPI](https://pypi.org/project/array-api-compat/) ``` python -m pip install array-api-compat ``` and [conda-forge](https://anaconda.org/conda-forge/array-api-compat) ``` conda install --channel conda-forge array-api-compat ``` ## Usage The typical usage of this library will be to get the corresponding array API compliant namespace from the input arrays using {func}`~.array_namespace()`, like ```py def your_function(x, y): xp = array_api_compat.array_namespace(x, y) # Now use xp as the array library namespace return xp.mean(x, axis=0) + 2*xp.std(y, axis=0) ``` If you wish to have library-specific code paths, you can import the corresponding wrapped namespace for each library, like ```py import array_api_compat.numpy as np ``` ```py import array_api_compat.cupy as cp ``` ```py import array_api_compat.torch as torch ``` ```py import array_api_compat.dask as da ``` ```{note} There are no `array_api_compat` submodules for JAX, sparse, or ndonnx. Support for these libraries is contained in the libraries themselves (JAX support is in the `jax.numpy` module in JAX v0.4.32 or newer, and in the `jax.experimental.array_api` module for older JAX versions). The array-api-compat support for these libraries consists of supporting them in the [helper functions](helper-functions). ``` Each will include all the functions from the normal NumPy/CuPy/PyTorch/dask.array namespace, except that functions that are part of the array API are wrapped so that they have the correct array API behavior. In each case, the array object used will be the same array object from the wrapped library. (array-api-strict)= ## Difference between `array_api_compat` and `array_api_strict` [`array_api_strict`](https://data-apis.org/array-api-strict/) is a strict minimal implementation of the array API standard, formerly known as `numpy.array_api` (see [NEP 47](https://numpy.org/neps/nep-0047-array-api-standard.html)). For example, `array_api_strict` does not include any functions that are not part of the array API specification, and will explicitly disallow behaviors that are not required by the spec (e.g., [cross-kind type promotions](https://data-apis.org/array-api/latest/API_specification/type_promotion.html)). (`cupy.array_api` is similar to `array_api_strict`.) `array_api_compat`, on the other hand, is just an extension of the corresponding array library namespaces with changes needed to be compliant with the array API. It includes all additional library functions not mentioned in the spec, and allows any library behaviors not explicitly disallowed by it, such as cross-kind casting. In particular, unlike `array_api_strict`, this package does not use a separate `Array` object, but rather just uses the corresponding array library array objects (`numpy.ndarray`, `cupy.ndarray`, `torch.Tensor`, etc.) directly. This is because those are the objects that are going to be passed as inputs to functions by end users. This does mean that a few behaviors cannot be wrapped (see below), but most of the array API is functional, so this does not affect most things.
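To make the contrast concrete, here is a small sketch (the exact error message from `array_api_strict` may differ): ```py import array_api_compat.numpy as xp import array_api_strict # Cross-kind (integer with floating-point) promotion is left unspecified # by the standard. array_api_compat keeps the underlying NumPy behavior: x = xp.asarray([1, 2], dtype=xp.int64) y = xp.asarray([0.5, 1.5], dtype=xp.float64) xp.add(x, y) # works, returns a float64 numpy.ndarray # array_api_strict explicitly disallows the same unspecified promotion: a = array_api_strict.asarray([1, 2], dtype=array_api_strict.int64) b = array_api_strict.asarray([0.5, 1.5], dtype=array_api_strict.float64) array_api_strict.add(a, b) # raises TypeError: promotion is not defined ```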
Array consuming library authors coding against the array API may wish to test against `array_api_strict` to ensure they are not using functionality outside of the standard, but prefer this implementation for the default behavior for end-users. (vendoring)= ## Vendoring This library supports vendoring as an installation method. To vendor the library, simply copy `array_api_compat` into the appropriate place in the library, like ``` cp -R array_api_compat/ mylib/vendored/array_api_compat ``` You may also rename it to something else if you like (nothing in the code references the name "array_api_compat"). Alternatively, the library may be installed as a dependency from PyPI. (scope)= ## Scope At this time, the scope of array-api-compat is limited to wrapping array libraries so that they can comply with the [array API standard](https://data-apis.org/array-api/latest/API_specification/index.html). This includes a small set of [helper functions](helper-functions.rst) which may be useful to most users of array-api-compat, for instance, functions that provide meta-functionality to aid in supporting the array API, or functions that are necessary to work around wrapping limitations for certain libraries. Things that are out-of-scope include: - functions that have not yet been standardized (although note that functions that are in a draft version of the standard are *in scope*), - functions that are complicated to implement correctly/maintain, - anything that requires the use of non-Python code. If you want a function that is not in array-api-compat that isn't part of the standard, you should request it either for [inclusion in the standard](https://github.com/data-apis/array-api/issues) or in specific array libraries. Why is the scope limited in this way? Firstly, we want to keep array-api-compat as primarily a [polyfill](https://en.wikipedia.org/wiki/Polyfill_(programming)) compatibility shim. The goal is to let consuming libraries use the array API today, even with array libraries that do not yet fully support it. In an ideal world---one that we hope to eventually see in the future---array-api-compat would be unnecessary, because every array library would fully support the standard. The inclusion of non-standardized functions in array-api-compat would undermine this goal. But much more importantly, it would also undermine the goals of the [Data APIs Consortium](https://data-apis.org/). The Consortium creates the array API standard via the consensus of stakeholders from various array libraries and users. If a not-yet-standardized function were included in array-api-compat, it would become *de facto* standard, bypassing the decision making processes of the Consortium. Secondly, we want to keep array-api-compat as minimal as possible, so that it is easy for libraries to add as a (possibly vendored) dependency. Thirdly, array-api-compat has a relatively small development team. Pull requests to array-api-compat would not necessarily receive the same stringent level of scrutiny that changes to established array libraries like NumPy or PyTorch would. For wrapped standard functions, this is fine, since the wrappers typically just clean up a few small inconsistencies from the standard, leaving the complexity of the implementation to the base array library function. Furthermore, standard functions are tested by the rigorous [array-api-tests](https://github.com/data-apis/array-api-tests) test suite.
For this reason, functions that require complex implementations are generally out-of-scope and are better implemented in upstream array libraries. ```{toctree} :titlesonly: :hidden: helper-functions.rst supported-array-libraries.md changelog.md dev/index.md ``` array-api-compat-1.11.2/docs/make.bat000066400000000000000000000014401476700770300173110ustar00rootroot00000000000000@ECHO OFF pushd %~dp0 REM Command file for Sphinx documentation if "%SPHINXBUILD%" == "" ( set SPHINXBUILD=sphinx-build ) set SOURCEDIR=. set BUILDDIR=_build %SPHINXBUILD% >NUL 2>NUL if errorlevel 9009 ( echo. echo.The 'sphinx-build' command was not found. Make sure you have Sphinx echo.installed, then set the SPHINXBUILD environment variable to point echo.to the full path of the 'sphinx-build' executable. Alternatively you echo.may add the Sphinx directory to PATH. echo. echo.If you don't have Sphinx installed, grab it from echo.https://www.sphinx-doc.org/ exit /b 1 ) if "%1" == "" goto help %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% goto end :help %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% :end popd array-api-compat-1.11.2/docs/requirements.txt000066400000000000000000000001111476700770300211620ustar00rootroot00000000000000furo linkify-it-py myst-parser sphinx sphinx-copybutton sphinx-autobuild array-api-compat-1.11.2/docs/supported-array-libraries.md000066400000000000000000000150111476700770300233400ustar00rootroot00000000000000# Supported Array Libraries The following array libraries are supported. This page outlines the known differences between this library and the array API specification for the supported packages. Note that the {func}`~.array_namespace()` helper will also support any array library that explicitly supports the array API by defining [`__array_namespace__`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__array_namespace__.html). Any reasonably popular array library is in-scope for array-api-compat, assuming it is possible to wrap it to support the array API without too much complexity. If your favorite library is not supported, feel free to open an [issue or pull request](https://github.com/data-apis/array-api-compat/issues). ## [NumPy](https://numpy.org/) and [CuPy](https://cupy.dev/) NumPy 2.0 has full array API compatibility. This package is not strictly necessary for NumPy 2.0 support, but may still be useful for the support of other libraries, as well as for the [helper functions](helper-functions.rst). For NumPy 1.26, as well as corresponding versions of CuPy, the following deviations from the standard should be noted: - The array methods `__array_namespace__`, `device` (for NumPy), `to_device`, and `mT` are not defined. This is because this package reuses `np.ndarray` and `cp.ndarray` directly, and we don't want to monkey patch or wrap them. The [helper functions](helper-functions.rst) {func}`~.device()` and {func}`~.to_device()` are provided to work around these missing methods. `x.mT` can be replaced with `xp.linalg.matrix_transpose(x)`. {func}`~.array_namespace()` should be used instead of `x.__array_namespace__`. - Value-based casting for scalars will be in effect unless explicitly disabled with the environment variable `NPY_PROMOTION_STATE=weak` or `np._set_promotion_state('weak')` (requires NumPy 1.24 or newer, see [NEP 50](https://numpy.org/neps/nep-0050-scalar-promotion.html) and https://github.com/numpy/numpy/issues/22341) - `asarray()` does not support `copy=False`.
The minimum supported NumPy version is 1.21. However, this older version of NumPy has a few issues:

- `unique_*` will not compare nans as unequal.
- `finfo()` has no `smallest_normal`.
- No `from_dlpack` or `__dlpack__`.
- `argmax()` and `argmin()` do not have `keepdims`.
- `qr()` doesn't support matrix stacks.
- `asarray()` doesn't support `copy=True` (as noted above, `copy=False` is not supported even in the latest NumPy).
- Type promotion behavior will be value based for 0-D arrays (and there is no `NPY_PROMOTION_STATE=weak` to disable this).

If any of these are an issue, it is recommended to bump your minimum NumPy version.

## [PyTorch](https://pytorch.org/)

- Like NumPy/CuPy, we do not wrap the `torch.Tensor` object. It is missing the `__array_namespace__` and `to_device` methods, so the corresponding helper functions {func}`~.array_namespace()` and {func}`~.to_device()` in this library should be used instead.
- {external+torch:meth}`x.size() <torch.Tensor.size>` on `torch.Tensor` is a method, and it behaves differently from the [`x.size`](https://data-apis.org/array-api/draft/API_specification/generated/array_api.array.size.html) attribute in the spec. Use the {func}`~.size()` helper function as a portable workaround.
- PyTorch does not have unsigned integer types other than `uint8`, and no attempt is made to implement them here.
- PyTorch has type promotion semantics that differ from the array API specification for 0-D tensor objects. The array functions in this wrapper library do work around this, but the operators on the Tensor object do not, as no operators or methods on the Tensor object are modified. If this is a concern, use the functional form instead of the operator form, e.g., `add(x, y)` instead of `x + y`.
- [`unique_all()`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.unique_all.html#array_api.unique_all) is not implemented, because `torch.unique` does not support returning the `indices` array. The other [`unique_*`](https://data-apis.org/array-api/latest/API_specification/set_functions.html) functions are implemented.
- Slices do not support negative steps.
- [`std()`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.std.html#array_api.std) and [`var()`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.var.html#array_api.var) do not support floating-point `correction`.
- The `stream` argument of the {func}`~.to_device()` helper is not supported.
- As with NumPy, type annotations and positional-only arguments may not exactly match the spec for functions that are not wrapped at all.

The minimum supported PyTorch version is 1.13.

(jax-support)=
## [JAX](https://jax.readthedocs.io/en/latest/)

Unlike the other libraries supported here, JAX array API support is contained entirely in the JAX library itself. JAX array API support is tracked at https://github.com/google/jax/issues/18353.

## [Dask](https://www.dask.org/)

If you're using Dask with NumPy, many of the limitations that apply to NumPy will also apply to Dask. Beyond those, the main additional limitation is limited support for the optional `linalg` and `fft` extensions (note that `sort()` and `argsort()` *are* implemented by this wrapper, as exercised in `tests/test_dask.py`, although sorting along a chunked axis requires an internal rechunk, which can be expensive). In particular, the `fft` namespace is not compliant with the array API spec; see the sketch and details below.
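Before the `fft` details, a quick orientation: the wrapped Dask namespace is obtained like any other, and operations on it stay lazy. A minimal sketch, mirroring how this repo's own Dask tests construct the namespace:

```py
import dask.array as da
from array_api_compat import array_namespace

x = da.asarray([3, 1, 2])
xp = array_namespace(x)   # returns array_api_compat.dask.array
y = xp.sort(x)            # lazy: only builds the task graph
print(y.compute())        # [1 2 3]
```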
Any functions that you find under the `fft` namespace are the original, unwrapped functions under [`dask.array.fft`](https://docs.dask.org/en/latest/array-api.html#fast-fourier-transforms), which may or may not be Array API compliant. Use at your own risk! For `linalg`, several methods are missing, for example: - `cross` - `det` - `eigh` - `eigvalsh` - `matrix_power` - `pinv` - `slogdet` - `matrix_norm` - `matrix_rank` Other methods may only be partially implemented or return incorrect results at times. The minimum supported Dask version is 2023.12.0. (sparse-support)= ## [Sparse](https://sparse.pydata.org/en/stable/) Similar to JAX, `sparse` Array API support is contained directly in `sparse`. (ndonnx-support)= ## [ndonnx](https://github.com/quantco/ndonnx) Similar to JAX, `ndonnx` Array API support is contained directly in `ndonnx`. (array-api-strict-support)= ## [array-api-strict](https://data-apis.org/array-api-strict/) array-api-strict exists only to test support for the Array API, so it does not need any wrappers. array-api-compat-1.11.2/numpy-1-21-xfails.txt000066400000000000000000000440131476700770300205320ustar00rootroot00000000000000# asarray(copy=False) is not yet implemented array_api_tests/test_creation_functions.py::test_asarray_arrays # finfo(float32).eps returns float32 but should return float array_api_tests/test_data_type_functions.py::test_finfo[float32] # Array methods and attributes not already on np.ndarray cannot be wrapped array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] array_api_tests/test_has_names.py::test_has_names[array_method-to_device] array_api_tests/test_has_names.py::test_has_names[array_attribute-device] array_api_tests/test_has_names.py::test_has_names[array_attribute-mT] # Array methods and attributes not already on np.ndarray cannot be wrapped array_api_tests/test_signatures.py::test_array_method_signature[__array_namespace__] array_api_tests/test_signatures.py::test_array_method_signature[to_device] # NumPy deviates in some special cases for floordiv array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and 
x2_i is +infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] # https://github.com/numpy/numpy/issues/21213 array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices # NumPy 1.21 specific XFAILS ############################ # finfo has no smallest_normal array_api_tests/test_data_type_functions.py::test_finfo[float64] # dlpack stuff array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack] array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack__] array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack_device__] array_api_tests/test_signatures.py::test_func_signature[from_dlpack] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack_device__] # qr() doesn't support matrix stacks array_api_tests/test_linalg.py::test_qr # cross has some promotion bug that is fixed in newer numpy versions array_api_tests/test_linalg.py::test_cross # vector_norm with ord=-1 which has since been fixed # https://github.com/numpy/numpy/issues/21083 array_api_tests/test_linalg.py::test_vector_norm # argmax and argmin do not support keepdims array_api_tests/test_searching_functions.py::test_argmax array_api_tests/test_searching_functions.py::test_argmin array_api_tests/test_signatures.py::test_func_signature[argmax] array_api_tests/test_signatures.py::test_func_signature[argmin] # NumPy 1.21 doesn't support NPY_PROMOTION_STATE=weak, so many tests fail with # type promotion issues array_api_tests/test_manipulation_functions.py::test_concat array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_add[__iadd__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_add[add(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_atan2 array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__and__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__and__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__iand__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[bitwise_and(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__ilshift__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__lshift__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__lshift__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[bitwise_left_shift(x1, x2)] 
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[__ior__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[__or__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[__or__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[bitwise_or(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__irshift__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__rshift__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__rshift__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[bitwise_right_shift(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__ixor__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__xor__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__xor__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[bitwise_xor(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_copysign array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_divide[divide(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_equal[__eq__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_equal[equal(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_greater[greater(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_hypot array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_less[less(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[__le__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[less_equal(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_logaddexp array_api_tests/test_operators_and_elementwise_functions.py::test_maximum array_api_tests/test_operators_and_elementwise_functions.py::test_minimum array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[multiply(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[__ne__(x1, x2)] 
array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[not_equal(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__ipow__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__pow__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__pow__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_pow[pow(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__isub__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[subtract(x1, x2)] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i < 0 and x2_i is +0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i < 0 and x2_i is -0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i > 0 and x2_i is +0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i > 0 and x2_i is -0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +0 and x2_i < 0) -> -0] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +0 and x2_i > 0) -> +0] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -0 and x2_i < 0) -> +0] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -0 and x2_i > 0) -> -0] array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> x2_i] array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> x1_i] array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i > 0 and x2_i is +infinity) -> x1_i] array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> x2_i] array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is +0 and x2_i < 0) -> -0] array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is +0 and x2_i > 0) -> +0] array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is -0 and x2_i < 0) -> -0] array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is -0 and x2_i > 0) -> +0] array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i < 0 and isfinite(x1_i) and isfinite(x2_i) and not x2_i.is_integer()) -> NaN] array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is +0 and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is +0 and x2_i > 0) -> +0] array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i < 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i < 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -infinity] array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -0] 
array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i < 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i < 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -0] array_api_tests/test_special_cases.py::test_binary[atan2(x1_i < 0 and x2_i is +0) -> roughly -pi/2] array_api_tests/test_special_cases.py::test_binary[atan2(x1_i < 0 and x2_i is -0) -> roughly -pi/2] array_api_tests/test_special_cases.py::test_binary[atan2(x1_i > 0 and x2_i is +0) -> roughly +pi/2] array_api_tests/test_special_cases.py::test_binary[atan2(x1_i > 0 and x2_i is -0) -> roughly +pi/2] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i < 0 and x2_i is +0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i < 0 and x2_i is -0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i > 0 and x2_i is +0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i > 0 and x2_i is -0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +0 and x2_i < 0) -> -0] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +0 and x2_i > 0) -> +0] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -0 and x2_i < 0) -> +0] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -0 and x2_i > 0) -> -0] array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> x2_i] array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> x1_i] array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i > 0 and x2_i is +infinity) -> x1_i] array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> x2_i] array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is +0 and x2_i < 0) -> -0] array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is +0 and x2_i > 0) -> +0] array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i < 0) -> -0] array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i > 0) -> +0] array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is -0 and x2_i is -0) -> -0] # 2023.12 support array_api_tests/test_signatures.py::test_func_signature[from_dlpack] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] # uint64 repeats not supported array_api_tests/test_manipulation_functions.py::test_repeat # 2024.12 support array_api_tests/test_signatures.py::test_func_signature[bitwise_and] array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_or] array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] # Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently,NumPy does just that array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) 
-> -infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] array-api-compat-1.11.2/numpy-1-26-xfails.txt000066400000000000000000000141001476700770300205310ustar00rootroot00000000000000# finfo(float32).eps returns float32 but should return float array_api_tests/test_data_type_functions.py::test_finfo[float32] # Array methods and attributes not already on np.ndarray cannot be wrapped array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] array_api_tests/test_has_names.py::test_has_names[array_method-to_device] array_api_tests/test_has_names.py::test_has_names[array_attribute-device] array_api_tests/test_has_names.py::test_has_names[array_attribute-mT] # Array methods and attributes not already on np.ndarray cannot be wrapped array_api_tests/test_signatures.py::test_array_method_signature[__array_namespace__] array_api_tests/test_signatures.py::test_array_method_signature[to_device] # NumPy deviates in some special cases for floordiv array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is 
+infinity) -> -0] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] # https://github.com/numpy/numpy/issues/21213 array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices # 2023.12 support array_api_tests/test_signatures.py::test_func_signature[from_dlpack] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] # uint64 repeats not supported array_api_tests/test_manipulation_functions.py::test_repeat # 2024.12 support array_api_tests/test_signatures.py::test_func_signature[bitwise_and] array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_or] array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] # Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] 
array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] array-api-compat-1.11.2/numpy-dev-xfails.txt000066400000000000000000000062741476700770300207370ustar00rootroot00000000000000# finfo(float32).eps returns float32 but should return float array_api_tests/test_data_type_functions.py::test_finfo[float32] # The test suite cannot properly get the signature for vecdot # https://github.com/numpy/numpy/pull/26237 array_api_tests/test_signatures.py::test_func_signature[vecdot] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot] # 2023.12 support # uint64 repeats not supported array_api_tests/test_manipulation_functions.py::test_repeat # 2024.12 support array_api_tests/test_signatures.py::test_func_signature[bitwise_and] array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_or] array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] # Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] 
array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] array-api-compat-1.11.2/numpy-skips.txt000066400000000000000000000020201476700770300200070ustar00rootroot00000000000000# These tests cause a core dump on CI, so we have to skip them entirely array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)] array-api-compat-1.11.2/numpy-xfails.txt000066400000000000000000000063201476700770300201530ustar00rootroot00000000000000# finfo(float32).eps returns float32 but should return float array_api_tests/test_data_type_functions.py::test_finfo[float32] # The test suite cannot properly get the signature for vecdot # https://github.com/numpy/numpy/pull/26237 array_api_tests/test_signatures.py::test_func_signature[vecdot] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot] # 2023.12 support array_api_tests/test_signatures.py::test_func_signature[from_dlpack] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] # uint64 repeats not supported array_api_tests/test_manipulation_functions.py::test_repeat # 2024.12 support array_api_tests/test_signatures.py::test_func_signature[bitwise_and] array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_or] array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] 
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] array-api-compat-1.11.2/requirements-dev.txt000066400000000000000000000001201476700770300210060ustar00rootroot00000000000000array-api-strict dask[array] jax[cpu] numpy pytest torch sparse >=0.15.1 ndonnx array-api-compat-1.11.2/ruff.toml000066400000000000000000000003441476700770300166150ustar00rootroot00000000000000[lint] preview = true select = [ # Defaults "E4", "E7", "E9", "F", # Undefined export "F822", # Useless import alias "PLC0414" ] ignore = [ # Module import not at top of file "E402", # Do not use bare `except` "E722" ] array-api-compat-1.11.2/setup.py000066400000000000000000000023411476700770300164670ustar00rootroot00000000000000from setuptools import setup, find_packages with open("README.md", "r") as fh: long_description = fh.read() import array_api_compat setup( name='array_api_compat', version=array_api_compat.__version__, packages=find_packages(include=["array_api_compat*"]), author="Consortium for Python Data API Standards", description="A wrapper around NumPy and other array libraries to make them compatible with the Array API standard", long_description=long_description, long_description_content_type="text/markdown", url="https://data-apis.org/array-api-compat/", license="MIT", extras_require={ "numpy": "numpy", "cupy": "cupy", "jax": "jax", "pytorch": "pytorch", "dask": "dask", "sparse": "sparse >=0.15.1", }, python_requires=">=3.9", classifiers=[ "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python 
:: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ] ) array-api-compat-1.11.2/test_cupy.sh000077500000000000000000000016751476700770300173440ustar00rootroot00000000000000#!/usr/bin/env bash # We cannot test cupy on CI so this script will test it manually. Assumes it # is being run in an environment that has cupy and the array-api-tests # dependencies installed set -x set -e # Run the vendoring tests in this repo pytest tmpdir=$(mktemp -d) SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) export PYTHONPATH="$PYTHONPATH:$SCRIPT_DIR" PYTEST_ARGS="--max-examples 200 -v -rxXfE --ci --hypothesis-disable-deadline" cd $tmpdir git clone https://github.com/data-apis/array-api-tests cd array-api-tests git submodule update --init # store the hypothesis examples database in this directory, so that failures # will be remembered across runs mkdir -p $SCRIPT_DIR/.hypothesis ln -s $SCRIPT_DIR/.hypothesis .hypothesis export ARRAY_API_TESTS_MODULE=array_api_compat.cupy export ARRAY_API_TESTS_VERSION=2024.12 pytest array_api_tests/ ${PYTEST_ARGS} --xfails-file $SCRIPT_DIR/cupy-xfails.txt "$@" array-api-compat-1.11.2/tests/000077500000000000000000000000001476700770300161175ustar00rootroot00000000000000array-api-compat-1.11.2/tests/__init__.py000066400000000000000000000003131476700770300202250ustar00rootroot00000000000000""" Basic tests for the compat library This only tests basic things like that vendoring works. The extensive tests are done by the array API test suite https://github.com/data-apis/array-api-tests """ array-api-compat-1.11.2/tests/_helpers.py000066400000000000000000000021601476700770300202710ustar00rootroot00000000000000from importlib import import_module import pytest wrapped_libraries = ["numpy", "cupy", "torch", "dask.array"] all_libraries = wrapped_libraries + [ "array_api_strict", "jax.numpy", "ndonnx", "sparse" ] def import_(library, wrapper=False): pytest.importorskip(library) if wrapper: if 'jax' in library: # JAX v0.4.32 implements the array API directly in jax.numpy # Older jax versions use jax.experimental.array_api jax_numpy = import_module("jax.numpy") if not hasattr(jax_numpy, "__array_api_version__"): library = 'jax.experimental.array_api' elif library in wrapped_libraries: library = 'array_api_compat.' + library return import_module(library) def xfail(request: pytest.FixtureRequest, reason: str) -> None: """ XFAIL the currently running test. Unlike ``pytest.xfail``, allow rest of test to execute instead of immediately halting it, so that it may result in a XPASS. xref https://github.com/pandas-dev/pandas/issues/38902 """ request.node.add_marker(pytest.mark.xfail(reason=reason)) array-api-compat-1.11.2/tests/test_all.py000066400000000000000000000034211476700770300203000ustar00rootroot00000000000000""" Test that files that define __all__ aren't missing any exports. You can add names that shouldn't be exported to _all_ignore, like _all_ignore = ['sys'] This is preferable to del-ing the names as this will break any name that is used inside of a function. Note that names starting with an underscore are automatically ignored. 
""" import sys from ._helpers import import_, wrapped_libraries import pytest @pytest.mark.skip(reason="TODO: starts failing after adding test_torch.py in gh-277") @pytest.mark.parametrize("library", ["common"] + wrapped_libraries) def test_all(library): if library == "common": import array_api_compat.common # noqa: F401 else: import_(library, wrapper=True) for mod_name in sys.modules: if not mod_name.startswith('array_api_compat.' + library): continue module = sys.modules[mod_name] # TODO: We should define __all__ in the __init__.py files and test it # there too. if not hasattr(module, '__all__'): continue dir_names = [n for n in dir(module) if not n.startswith('_')] if '__array_namespace_info__' in dir(module): dir_names.append('__array_namespace_info__') ignore_all_names = getattr(module, '_all_ignore', []) ignore_all_names += ['annotations', 'TYPE_CHECKING'] dir_names = set(dir_names) - set(ignore_all_names) all_names = module.__all__ if set(dir_names) != set(all_names): extra_dir = set(dir_names) - set(all_names) extra_all = set(all_names) - set(dir_names) assert not extra_dir, f"Some dir() names not included in __all__ for {mod_name}: {extra_dir}" assert not extra_all, f"Some __all__ names not in dir() for {mod_name}: {extra_all}" array-api-compat-1.11.2/tests/test_array_namespace.py000066400000000000000000000116571476700770300226740ustar00rootroot00000000000000import subprocess import sys import warnings import jax import numpy as np import pytest import torch import array_api_compat from array_api_compat import array_namespace from ._helpers import import_, all_libraries, wrapped_libraries @pytest.mark.parametrize("use_compat", [True, False, None]) @pytest.mark.parametrize("api_version", [None, "2021.12", "2022.12", "2023.12"]) @pytest.mark.parametrize("library", all_libraries) def test_array_namespace(library, api_version, use_compat): xp = import_(library) array = xp.asarray([1.0, 2.0, 3.0]) if use_compat and library not in wrapped_libraries: pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat)) return if library == "ndonnx" and api_version in ("2021.12", "2022.12"): pytest.skip("Unsupported API version") namespace = array_namespace(array, api_version=api_version, use_compat=use_compat) if use_compat is False or use_compat is None and library not in wrapped_libraries: if library == "jax.numpy" and use_compat is None: import jax.numpy if hasattr(jax.numpy, "__array_api_version__"): # JAX v0.4.32 or later uses jax.numpy directly assert namespace == jax.numpy else: # JAX v0.4.31 or earlier uses jax.experimental.array_api import jax.experimental.array_api assert namespace == jax.experimental.array_api else: assert namespace == xp else: if library == "dask.array": assert namespace == array_api_compat.dask.array else: assert namespace == getattr(array_api_compat, library) if library == "numpy": # check that the same namespace is returned for NumPy scalars scalar_namespace = array_namespace( xp.float64(0.0), api_version=api_version, use_compat=use_compat ) assert scalar_namespace == namespace # Check that array_namespace works even if jax.experimental.array_api # hasn't been imported yet (it monkeypatches __array_namespace__ # onto JAX arrays, but we should support them regardless). The only way to # do this is to use a subprocess, since we cannot un-import it and another # test probably already imported it. 
if library == "jax.numpy" and sys.version_info >= (3, 9): code = f"""\ import sys import jax.numpy import array_api_compat array = jax.numpy.asarray([1.0, 2.0, 3.0]) assert 'jax.experimental.array_api' not in sys.modules namespace = array_api_compat.array_namespace(array, api_version={api_version!r}) if hasattr(jax.numpy, '__array_api_version__'): assert namespace == jax.numpy else: import jax.experimental.array_api assert namespace == jax.experimental.array_api """ subprocess.run([sys.executable, "-c", code], check=True) def test_jax_zero_gradient(): jx = jax.numpy.arange(4) jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx) assert array_namespace(jax_zero) is array_namespace(jx) def test_array_namespace_errors(): pytest.raises(TypeError, lambda: array_namespace([1])) pytest.raises(TypeError, lambda: array_namespace()) x = np.asarray([1, 2]) pytest.raises(TypeError, lambda: array_namespace((x, x))) pytest.raises(TypeError, lambda: array_namespace(x, (x, x))) def test_array_namespace_errors_torch(): y = torch.asarray([1, 2]) x = np.asarray([1, 2]) pytest.raises(TypeError, lambda: array_namespace(x, y)) def test_api_version_torch(): x = torch.asarray([1, 2]) torch_ = import_("torch", wrapper=True) assert array_namespace(x, api_version="2023.12") == torch_ assert array_namespace(x, api_version=None) == torch_ assert array_namespace(x) == torch_ # Should issue a warning with warnings.catch_warnings(record=True) as w: assert array_namespace(x, api_version="2021.12") == torch_ assert len(w) == 1 assert "2021.12" in str(w[0].message) # Should issue a warning with warnings.catch_warnings(record=True) as w: assert array_namespace(x, api_version="2022.12") == torch_ assert len(w) == 1 assert "2022.12" in str(w[0].message) pytest.raises(ValueError, lambda: array_namespace(x, api_version="2020.12")) def test_get_namespace(): # Backwards compatible wrapper assert array_api_compat.get_namespace is array_namespace def test_python_scalars(): a = torch.asarray([1, 2]) xp = import_("torch", wrapper=True) pytest.raises(TypeError, lambda: array_namespace(1)) pytest.raises(TypeError, lambda: array_namespace(1.0)) pytest.raises(TypeError, lambda: array_namespace(1j)) pytest.raises(TypeError, lambda: array_namespace(True)) pytest.raises(TypeError, lambda: array_namespace(None)) assert array_namespace(a, 1) == xp assert array_namespace(a, 1.0) == xp assert array_namespace(a, 1j) == xp assert array_namespace(a, True) == xp assert array_namespace(a, None) == xp array-api-compat-1.11.2/tests/test_common.py000066400000000000000000000304611476700770300210240ustar00rootroot00000000000000import math import pytest import numpy as np import array from numpy.testing import assert_equal from array_api_compat import ( # noqa: F401 is_numpy_array, is_cupy_array, is_torch_array, is_dask_array, is_jax_array, is_pydata_sparse_array, is_ndonnx_array, is_numpy_namespace, is_cupy_namespace, is_torch_namespace, is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace, is_array_api_strict_namespace, is_ndonnx_namespace, ) from array_api_compat import ( device, is_array_api_obj, is_lazy_array, is_writeable_array, size, to_device ) from ._helpers import all_libraries, import_, wrapped_libraries, xfail is_array_functions = { 'numpy': 'is_numpy_array', 'cupy': 'is_cupy_array', 'torch': 'is_torch_array', 'dask.array': 'is_dask_array', 'jax.numpy': 'is_jax_array', 'sparse': 'is_pydata_sparse_array', 'ndonnx': 'is_ndonnx_array', } is_namespace_functions = { 'numpy': 'is_numpy_namespace', 'cupy': 
'is_cupy_namespace', 'torch': 'is_torch_namespace', 'dask.array': 'is_dask_namespace', 'jax.numpy': 'is_jax_namespace', 'sparse': 'is_pydata_sparse_namespace', 'array_api_strict': 'is_array_api_strict_namespace', 'ndonnx': 'is_ndonnx_namespace', } @pytest.mark.parametrize('library', is_array_functions.keys()) @pytest.mark.parametrize('func', is_array_functions.values()) def test_is_xp_array(library, func): lib = import_(library) is_func = globals()[func] x = lib.asarray([1, 2, 3]) assert is_func(x) == (func == is_array_functions[library]) assert is_array_api_obj(x) @pytest.mark.parametrize('library', is_namespace_functions.keys()) @pytest.mark.parametrize('func', is_namespace_functions.values()) def test_is_xp_namespace(library, func): lib = import_(library) is_func = globals()[func] assert is_func(lib) == (func == is_namespace_functions[library]) @pytest.mark.parametrize('library', all_libraries) def test_xp_is_array_generics(library): """ Test that scalar selection on a xp.ndarray always returns an object that matches with exactly one among the is_*_array function of the same library and is_numpy_array. """ lib = import_(library) x = lib.asarray([1, 2, 3]) x0 = x[0] matches = [] for library2, func in is_array_functions.items(): is_func = globals()[func] if is_func(x0): matches.append(library2) if library == "array_api_strict": # There is no is_array_api_strict_array() function assert matches == [] else: assert matches in ([library], ["numpy"]) @pytest.mark.parametrize("library", all_libraries) def test_is_writeable_array(library): lib = import_(library) x = lib.asarray([1, 2, 3]) if is_writeable_array(x): x[1] = 4 else: with pytest.raises((TypeError, ValueError)): x[1] = 4 def test_is_writeable_array_numpy(): x = np.asarray([1, 2, 3]) assert is_writeable_array(x) x.flags.writeable = False assert not is_writeable_array(x) @pytest.mark.parametrize("library", all_libraries) def test_size(library): xp = import_(library) x = xp.asarray([1, 2, 3]) assert size(x) == 3 @pytest.mark.parametrize("library", all_libraries) def test_size_none(library): if library == "sparse": pytest.skip("No arange(); no indexing by sparse arrays") xp = import_(library) x = xp.arange(10) x = x[x < 5] # dask.array now has shape=(nan, ) and size=nan # ndonnx now has shape=(None, ) and size=None # Eager libraries have shape=(5, ) and size=5 assert size(x) in (None, 5) @pytest.mark.parametrize("library", all_libraries) def test_is_lazy_array(library): lib = import_(library) x = lib.asarray([1, 2, 3]) assert isinstance(is_lazy_array(x), bool) @pytest.mark.parametrize("shape", [(math.nan,), (1, math.nan), (None, ), (1, None)]) def test_is_lazy_array_nan_size(shape, monkeypatch): """Test is_lazy_array() on an unknown Array API compliant object with NaN (like Dask) or None (like ndonnx) in its shape """ xp = import_("array_api_strict") x = xp.asarray(1) assert not is_lazy_array(x) monkeypatch.setattr(type(x), "shape", shape) assert is_lazy_array(x) @pytest.mark.parametrize("exc", [TypeError, AssertionError]) def test_is_lazy_array_bool_raises(exc, monkeypatch): """Test is_lazy_array() on an unknown Array API compliant object where calling bool() raises: - TypeError: e.g. like jitted JAX. This is the proper exception which lazy arrays should raise as per the Array API specification - something else: e.g. 
like Dask, where bool() triggers compute() which can result in any kind of exception to be raised """ xp = import_("array_api_strict") x = xp.asarray(1) assert not is_lazy_array(x) def __bool__(self): raise exc("Hello world") monkeypatch.setattr(type(x), "__bool__", __bool__) assert is_lazy_array(x) @pytest.mark.parametrize( 'func', list(is_array_functions.values()) + ["is_array_api_obj", "is_lazy_array", "is_writeable_array"] ) def test_is_array_any_object(func): """Test that is_*_array functions return False and don't raise on non-array objects """ func = globals()[func] # These objects are missing attributes such as __name__ assert not func(object()) assert not func(None) assert not func(1) class C: pass assert not func(C()) @pytest.mark.parametrize("library", all_libraries) def test_device(library, request): if library == "ndonnx": xfail(request, reason="Needs ndonnx >=0.9.4") xp = import_(library, wrapper=True) # We can't test much for device() and to_device() other than that # x.to_device(x.device) works. x = xp.asarray([1, 2, 3]) dev = device(x) x2 = to_device(x, dev) assert device(x2) == device(x) x3 = xp.asarray(x, device=dev) assert device(x3) == device(x) @pytest.mark.parametrize("library", wrapped_libraries) def test_to_device_host(library): # different libraries have different semantics # for DtoH transfers; ensure that we support a portable # shim for common array libs # see: https://github.com/scipy/scipy/issues/18286#issuecomment-1527552919 xp = import_(library, wrapper=True) expected = np.array([1, 2, 3]) x = xp.asarray([1, 2, 3]) x = to_device(x, "cpu") # torch will return a genuine Device object, but # the other libs will do something different with # a `device(x)` query; however, what's really important # here is that we can test portably after calling # to_device(x, "cpu") to return to host assert_equal(x, expected) @pytest.mark.parametrize("target_library", is_array_functions.keys()) @pytest.mark.parametrize("source_library", is_array_functions.keys()) def test_asarray_cross_library(source_library, target_library, request): if source_library == "dask.array" and target_library == "torch": # TODO: remove xfail once # https://github.com/dask/dask/issues/8260 is resolved xfail(request, reason="Bug in dask raising error on conversion") elif ( source_library == "ndonnx" and target_library not in ("array_api_strict", "ndonnx", "numpy") ): xfail(request, reason="The truth value of lazy Array Array(dtype=Boolean) is unknown") elif source_library == "ndonnx" and target_library == "numpy": xfail(request, reason="produces numpy array of ndonnx scalar arrays") elif source_library == "jax.numpy" and target_library == "torch": xfail(request, reason="casts int to float") elif source_library == "cupy" and target_library != "cupy": # cupy explicitly disallows implicit conversions to CPU pytest.skip(reason="cupy does not support implicit conversion to CPU") elif source_library == "sparse" and target_library != "sparse": pytest.skip(reason="`sparse` does not allow implicit densification") src_lib = import_(source_library, wrapper=True) tgt_lib = import_(target_library, wrapper=True) is_tgt_type = globals()[is_array_functions[target_library]] a = src_lib.asarray([1, 2, 3], dtype=src_lib.int32) b = tgt_lib.asarray(a) assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}" assert b.dtype == tgt_lib.int32 @pytest.mark.parametrize("library", wrapped_libraries) def test_asarray_copy(library): # Note, we have this test here because the test suite currently doesn't # 
test the copy flag to asarray() very rigorously. Once # https://github.com/data-apis/array-api-tests/issues/241 is fixed we # should be able to delete this. xp = import_(library, wrapper=True) asarray = xp.asarray is_lib_func = globals()[is_array_functions[library]] all = xp.all if library != 'dask.array' else lambda x: xp.all(x).compute() if library == 'numpy' and xp.__version__[0] < '2' and not hasattr(xp, '_CopyMode') : supports_copy_false_other_ns = False supports_copy_false_same_ns = False elif library == 'cupy': supports_copy_false_other_ns = False supports_copy_false_same_ns = False elif library == 'dask.array': supports_copy_false_other_ns = False supports_copy_false_same_ns = True else: supports_copy_false_other_ns = True supports_copy_false_same_ns = True a = asarray([1]) b = asarray(a, copy=True) assert is_lib_func(b) a[0] = 0 assert all(b[0] == 1) assert all(a[0] == 0) a = asarray([1]) if supports_copy_false_same_ns: b = asarray(a, copy=False) assert is_lib_func(b) a[0] = 0 assert all(b[0] == 0) else: pytest.raises(NotImplementedError, lambda: asarray(a, copy=False)) a = asarray([1]) if supports_copy_false_same_ns: pytest.raises(ValueError, lambda: asarray(a, copy=False, dtype=xp.float64)) else: pytest.raises(NotImplementedError, lambda: asarray(a, copy=False, dtype=xp.float64)) a = asarray([1]) b = asarray(a, copy=None) assert is_lib_func(b) a[0] = 0 assert all(b[0] == 0) a = asarray([1.0], dtype=xp.float32) assert a.dtype == xp.float32 b = asarray(a, dtype=xp.float64, copy=None) assert is_lib_func(b) assert b.dtype == xp.float64 a[0] = 0.0 assert all(b[0] == 1.0) a = asarray([1.0], dtype=xp.float64) assert a.dtype == xp.float64 b = asarray(a, dtype=xp.float64, copy=None) assert is_lib_func(b) assert b.dtype == xp.float64 a[0] = 0.0 assert all(b[0] == 0.0) # Python built-in types for obj in [True, 0, 0.0, 0j, [0], [[0]]]: asarray(obj, copy=True) # No error asarray(obj, copy=None) # No error if supports_copy_false_other_ns: pytest.raises(ValueError, lambda: asarray(obj, copy=False)) else: pytest.raises(NotImplementedError, lambda: asarray(obj, copy=False)) # Use the standard library array to test the buffer protocol a = array.array('f', [1.0]) b = asarray(a, copy=True) assert is_lib_func(b) a[0] = 0.0 assert all(b[0] == 1.0) a = array.array('f', [1.0]) if supports_copy_false_other_ns: b = asarray(a, copy=False) assert is_lib_func(b) a[0] = 0.0 assert all(b[0] == 0.0) else: pytest.raises(NotImplementedError, lambda: asarray(a, copy=False)) a = array.array('f', [1.0]) b = asarray(a, copy=None) assert is_lib_func(b) a[0] = 0.0 if library in ('cupy', 'dask.array'): # A copy is required for libraries where the default device is not CPU # dask changed behaviour of copy=None in 2024.12 to copy; # this wrapper ensures the same behaviour in older versions too. 

@pytest.mark.parametrize("library", ["numpy", "cupy", "torch"])
def test_clip_out(library):
    """Test the non-standard out= parameter for clip()

    (See "Avoid Restricting Behavior that is Outside the Scope of the
    Standard" in
    https://data-apis.org/array-api-compat/dev/special-considerations.html)
    """
    xp = import_(library, wrapper=True)
    x = xp.asarray([10, 20, 30])
    out = xp.zeros_like(x)
    xp.clip(x, 15, 25, out=out)
    expect = xp.asarray([15, 20, 25])
    assert xp.all(out == expect)

array-api-compat-1.11.2/tests/test_dask.py
from contextlib import contextmanager

import array_api_strict
import dask
import numpy as np
import pytest
import dask.array as da

from array_api_compat import array_namespace


@pytest.fixture
def xp():
    """Fixture returning the wrapped dask namespace"""
    return array_namespace(da.empty(0))


@contextmanager
def assert_no_compute():
    """
    Context manager that raises if anything inside it calls compute() or
    persist(), e.g. as triggered implicitly by __bool__, __array__, etc.
    """
    def get(dsk, *args, **kwargs):
        raise AssertionError("Called compute() or persist()")

    with dask.config.set(scheduler=get):
        yield


def test_assert_no_compute():
    """Test the assert_no_compute context manager"""
    a = da.asarray(True)
    with pytest.raises(AssertionError, match="Called compute"):
        with assert_no_compute():
            bool(a)

    # Exiting the context manager restores the original scheduler
    assert bool(a) is True
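
# assert_no_compute() works because dask funnels every materialisation
# through the scheduler callable configured in dask.config. The same hook can
# observe rather than forbid computes; a hedged sketch (the helper name is an
# invention here, and it delegates to dask's synchronous scheduler dask.get):
def _make_counting_scheduler():
    counter = {"n": 0}

    def get(dsk, keys, **kwargs):
        counter["n"] += 1
        return dask.get(dsk, keys, **kwargs)  # run on the synchronous scheduler

    return counter, get
    # Usage: counter, get = _make_counting_scheduler()
    #        with dask.config.set(scheduler=get): ...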

# Test no_compute for functions that use generic _aliases with xp=np
def test_unary_ops_no_compute(xp):
    with assert_no_compute():
        a = xp.asarray([1.5, -1.5])
        xp.ceil(a)
        xp.floor(a)
        xp.trunc(a)
        xp.sign(a)


def test_matmul_tensordot_no_compute(xp):
    A = da.ones((4, 4), chunks=2)
    B = da.zeros((4, 4), chunks=2)
    with assert_no_compute():
        xp.matmul(A, B)
        xp.tensordot(A, B)


# Test no_compute for functions that are fully bespoke for dask
def test_asarray_no_compute(xp):
    with assert_no_compute():
        a = xp.arange(10)
        xp.asarray(a)
        xp.asarray(a, dtype=np.int16)
        xp.asarray(a, dtype=a.dtype)
        xp.asarray(a, copy=True)
        xp.asarray(a, copy=True, dtype=np.int16)
        xp.asarray(a, copy=True, dtype=a.dtype)
        xp.asarray(a, copy=False)
        xp.asarray(a, copy=False, dtype=a.dtype)


@pytest.mark.parametrize("copy", [True, False])
def test_astype_no_compute(xp, copy):
    with assert_no_compute():
        a = xp.arange(10)
        xp.astype(a, np.int16, copy=copy)
        xp.astype(a, a.dtype, copy=copy)


def test_clip_no_compute(xp):
    with assert_no_compute():
        a = xp.arange(10)
        xp.clip(a)
        xp.clip(a, 1)
        xp.clip(a, 1, 8)


@pytest.mark.parametrize("chunks", (5, 10))
def test_sort_argsort_no_compute(xp, chunks):
    with assert_no_compute():
        a = xp.arange(10, chunks=chunks)
        xp.sort(a)
        xp.argsort(a)


def test_generators_are_lazy(xp):
    """
    Test that generator functions are fully lazy, e.g. that
    da.ones(n) is not implemented as da.asarray(np.ones(n))
    """
    size = 100_000_000_000  # 800 GB
    chunks = size // 10  # 10x 80 GB chunks

    with assert_no_compute():
        xp.zeros(size, chunks=chunks)
        xp.ones(size, chunks=chunks)
        xp.empty(size, chunks=chunks)
        xp.full(size, fill_value=123, chunks=chunks)
        a = xp.arange(size, chunks=chunks)
        xp.zeros_like(a)
        xp.ones_like(a)
        xp.empty_like(a)
        xp.full_like(a, fill_value=123)


@pytest.mark.parametrize("axis", [0, 1])
@pytest.mark.parametrize("func", ["sort", "argsort"])
def test_sort_argsort_chunks(xp, func, axis):
    """Test that sort and argsort are functionally correct when the array is
    chunked along the sort axis, i.e. that the sort is not just local to each
    chunk.
    """
    a = da.random.random((10, 10), chunks=(5, 5))
    actual = getattr(xp, func)(a, axis=axis)
    expect = getattr(np, func)(a.compute(), axis=axis)
    np.testing.assert_array_equal(actual, expect)


@pytest.mark.parametrize(
    "shape,chunks",
    [
        # 3 GiB; 128 MiB per chunk; must rechunk before sorting.
        # Sort chunks can be 128 MiB each; no need for final rechunk.
        ((20_000, 20_000), "auto"),
        # 3 GiB; 128 MiB per chunk; must rechunk before sorting.
        # Must sort on two 1.5 GiB chunks; benefits from final rechunk.
        ((2, 2**30 * 3 // 16), "auto"),
        # 3 GiB; 1.5 GiB per chunk; no need to rechunk before sorting.
        # Surely the user must know what they're doing, so don't
        # perform the final rechunk.
        ((2, 2**30 * 3 // 16), (1, -1)),
    ],
)
@pytest.mark.parametrize("func", ["sort", "argsort"])
def test_sort_argsort_chunk_size(xp, func, shape, chunks):
    """
    Test that sort and argsort produce reasonably-sized chunks in the output
    array, even when they had to rechunk to a single huge chunk along the
    sort axis to perform the operation.
    """
    a = da.random.random(shape, chunks=chunks)
    b = getattr(xp, func)(a)
    max_chunk_size = max(b.chunks[0]) * max(b.chunks[1]) * b.dtype.itemsize
    assert (
        max_chunk_size <= 128 * 1024 * 1024  # 128 MiB
        or b.chunks == a.chunks
    )


@pytest.mark.parametrize("func", ["sort", "argsort"])
def test_sort_argsort_meta(xp, func):
    """Test meta-namespace other than numpy"""
    typ = type(array_api_strict.asarray(0))
    a = da.random.random(10)
    b = a.map_blocks(array_api_strict.asarray)
    assert isinstance(b._meta, typ)
    c = getattr(xp, func)(b)
    assert isinstance(c._meta, typ)
    d = c.compute()
    # Note: np.sort(array_api_strict.asarray(0)) would return a numpy array
    assert isinstance(d, typ)
    np.testing.assert_array_equal(d, getattr(np, func)(a.compute()))

array-api-compat-1.11.2/tests/test_isdtype.py
"""
isdtype is not yet tested in the test suite, and it should extend properly
to non-spec dtypes.
"""

import pytest

from ._helpers import import_, wrapped_libraries

# Check the known dtypes by their string names

def _spec_dtypes(library):
    if library == 'torch':
        # torch does not have unsigned integer dtypes
        return {
            'bool',
            'complex64',
            'complex128',
            'uint8',
            'int8',
            'int16',
            'int32',
            'int64',
            'float32',
            'float64',
        }
    else:
        return {
            'bool',
            'complex64',
            'complex128',
            'float32',
            'float64',
            'int16',
            'int32',
            'int64',
            'int8',
            'uint16',
            'uint32',
            'uint64',
            'uint8',
        }

dtype_categories = {
    'bool': lambda d: d == 'bool',
    'signed integer': lambda d: d.startswith('int'),
    'unsigned integer': lambda d: d.startswith('uint'),
    'integral': lambda d: dtype_categories['signed integer'](d) or
                          dtype_categories['unsigned integer'](d),
    'real floating': lambda d: 'float' in d,
    'complex floating': lambda d: d.startswith('complex'),
    'numeric': lambda d: dtype_categories['integral'](d) or
                         dtype_categories['real floating'](d) or
                         dtype_categories['complex floating'](d),
}
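
# A quick, pure-Python illustration of the reference predicates above; this
# self-test helper is an addition for clarity and is not part of the suite:
def _dtype_categories_selftest():
    assert dtype_categories['integral']('uint8')
    assert dtype_categories['real floating']('float32')
    assert dtype_categories['numeric']('complex64')
    assert not dtype_categories['numeric']('bool')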

def isdtype_(dtype_, kind):
    # Check a dtype_ string against kind. Note that 'bool' technically has
    # two meanings here, but they are both the same.
    if kind in dtype_categories:
        res = dtype_categories[kind](dtype_)
    else:
        res = dtype_ == kind
    assert type(res) is bool  # noqa: E721
    return res

@pytest.mark.parametrize("library", wrapped_libraries)
def test_isdtype_spec_dtypes(library):
    xp = import_(library, wrapper=True)

    isdtype = xp.isdtype

    for dtype_ in _spec_dtypes(library):
        for dtype2_ in _spec_dtypes(library):
            dtype = getattr(xp, dtype_)
            dtype2 = getattr(xp, dtype2_)
            res = isdtype_(dtype_, dtype2_)
            assert isdtype(dtype, dtype2) is res, (dtype_, dtype2_)

        for cat in dtype_categories:
            res = isdtype_(dtype_, cat)
            assert isdtype(dtype, cat) == res, (dtype_, cat)

        # Basic tuple testing (the array-api testsuite will be more complete here)
        for kind1_ in [*_spec_dtypes(library), *dtype_categories]:
            for kind2_ in [*_spec_dtypes(library), *dtype_categories]:
                kind1 = kind1_ if kind1_ in dtype_categories else getattr(xp, kind1_)
                kind2 = kind2_ if kind2_ in dtype_categories else getattr(xp, kind2_)
                kind = (kind1, kind2)
                res = isdtype_(dtype_, kind1_) or isdtype_(dtype_, kind2_)
                assert isdtype(dtype, kind) == res, (dtype_, (kind1_, kind2_))

additional_dtypes = [
    'float16',
    'float128',
    'complex256',
    'bfloat16',
]

@pytest.mark.parametrize("library", wrapped_libraries)
@pytest.mark.parametrize("dtype_", additional_dtypes)
def test_isdtype_additional_dtypes(library, dtype_):
    xp = import_(library, wrapper=True)

    isdtype = xp.isdtype

    if not hasattr(xp, dtype_):
        return
        # pytest.skip(f"{library} doesn't have dtype {dtype_}")

    dtype = getattr(xp, dtype_)

    for cat in dtype_categories:
        res = isdtype_(dtype_, cat)
        assert isdtype(dtype, cat) == res, (dtype_, cat)

array-api-compat-1.11.2/tests/test_jax.py
import jax
import jax.numpy as jnp
from numpy.testing import assert_equal
import pytest

from array_api_compat import device, to_device

# Compare parsed version components rather than raw strings: lexically,
# "0.4.4" >= "0.4.31", which is wrong. This assumes a plain numeric
# "X.Y.Z" release string for jax.__version__.
HAS_JAX_0_4_31 = tuple(map(int, jax.__version__.split(".")[:3])) >= (0, 4, 31)


@pytest.mark.parametrize(
    "func",
    [
        lambda x: jnp.zeros(1, device=device(x)),
        lambda x: jnp.zeros_like(jnp.ones(1, device=device(x))),
        lambda x: jnp.zeros_like(jnp.empty(1, device=device(x))),
        lambda x: jnp.full(1, fill_value=0, device=device(x)),
        pytest.param(
            lambda x: jnp.asarray([0], device=device(x)),
            marks=pytest.mark.skipif(
                not HAS_JAX_0_4_31, reason="asarray() has no device= parameter"
            ),
        ),
        lambda x: to_device(jnp.zeros(1), device(x)),
    ]
)
def test_device_jit(func):
    # Test the workaround for https://github.com/jax-ml/jax/issues/26000.
    # Also test the missing to_device() method in JAX < 0.4.31
    # when inside jax.jit, even after importing jax.experimental.array_api.
    x = jnp.ones(1)
    assert_equal(func(x), jnp.asarray([0]))
    assert_equal(jax.jit(func)(x), jnp.asarray([0]))
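
# For context on to_device() above: on JAX, moving an array to a concrete
# device is spelled jax.device_put. A minimal sketch of the equivalence this
# relies on, assuming `dev` is a jax.Device instance (illustrative only):
def _to_device_sketch(x, dev):
    return jax.device_put(x, dev)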
""" import sys import subprocess import pytest class Array: # Dummy array namespace that doesn't depend on any array library def __array_namespace__(self, api_version=None): class Namespace: pass return Namespace() def _test_dependency(mod): assert mod not in sys.modules # Run various functions that shouldn't depend on mod and check that they # don't import it. import array_api_compat assert mod not in sys.modules a = Array() # array-api-strict is an example of an array API library that isn't # wrapped by array-api-compat. if "strict" not in mod and mod != "sparse": is_mod_array = getattr(array_api_compat, f"is_{mod.split('.')[0]}_array") assert not is_mod_array(a) assert mod not in sys.modules is_array_api_obj = getattr(array_api_compat, "is_array_api_obj") assert is_array_api_obj(a) assert mod not in sys.modules array_namespace = getattr(array_api_compat, "array_namespace") array_namespace(Array()) assert mod not in sys.modules # TODO: Test that wrapper for library X doesn't depend on wrappers for library # Y (except most array libraries actually do themselves depend on numpy). @pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy", "sparse", "array_api_strict"]) def test_numpy_dependency(library): # This import is here because it imports numpy from ._helpers import import_ # This unfortunately won't go through any of the pytest machinery. We # reraise the exception as an AssertionError so that pytest will show it # in a semi-reasonable way # Import (in this process) to make sure 'library' is actually installed and # so that cupy can be skipped. import_(library) try: subprocess.run([sys.executable, '-c', f'''\ from tests.test_no_dependencies import _test_dependency _test_dependency({library!r})'''], check=True, capture_output=True, encoding='utf-8') except subprocess.CalledProcessError as e: print(e.stdout, end='') raise AssertionError(e.stderr) array-api-compat-1.11.2/tests/test_torch.py000066400000000000000000000063621476700770300206560ustar00rootroot00000000000000"""Test "unspecified" behavior which we cannot easily test in the Array API test suite. """ import itertools import pytest import torch from array_api_compat import torch as xp class TestResultType: def test_empty(self): with pytest.raises(ValueError): xp.result_type() def test_one_arg(self): for x in [1, 1.0, 1j, '...', None]: with pytest.raises((ValueError, AttributeError)): xp.result_type(x) for x in [xp.float32, xp.int64, torch.complex64]: assert xp.result_type(x) == x for x in [xp.asarray(True, dtype=xp.bool), xp.asarray(1, dtype=xp.complex64)]: assert xp.result_type(x) == x.dtype def test_two_args(self): # Only include here things "unspecified" in the spec # scalar, tensor or tensor,tensor for x, y in [ (1., 1j), (1j, xp.arange(3)), (True, xp.asarray(3.)), (xp.ones(3) == 1, 1j*xp.ones(3)), ]: assert xp.result_type(x, y) == torch.result_type(x, y) # dtype, scalar for x, y in [ (1j, xp.int64), (True, xp.float64), ]: assert xp.result_type(x, y) == torch.result_type(x, xp.empty([], dtype=y)) # dtype, dtype for x, y in [ (xp.bool, xp.complex64) ]: xt, yt = xp.empty([], dtype=x), xp.empty([], dtype=y) assert xp.result_type(x, y) == torch.result_type(xt, yt) def test_multi_arg(self): torch.set_default_dtype(torch.float32) args = [1., 5, 3, torch.asarray([3], dtype=torch.float16), 5, 6, 1.] 

array-api-compat-1.11.2/tests/test_torch.py
"""Test "unspecified" behavior which we cannot easily test in the Array API
test suite.
"""

import itertools

import pytest
import torch

from array_api_compat import torch as xp


class TestResultType:
    def test_empty(self):
        with pytest.raises(ValueError):
            xp.result_type()

    def test_one_arg(self):
        for x in [1, 1.0, 1j, '...', None]:
            with pytest.raises((ValueError, AttributeError)):
                xp.result_type(x)

        for x in [xp.float32, xp.int64, torch.complex64]:
            assert xp.result_type(x) == x

        for x in [xp.asarray(True, dtype=xp.bool), xp.asarray(1, dtype=xp.complex64)]:
            assert xp.result_type(x) == x.dtype

    def test_two_args(self):
        # Only include things here that are "unspecified" in the spec:
        # scalar, tensor or tensor, tensor
        for x, y in [
            (1., 1j),
            (1j, xp.arange(3)),
            (True, xp.asarray(3.)),
            (xp.ones(3) == 1, 1j*xp.ones(3)),
        ]:
            assert xp.result_type(x, y) == torch.result_type(x, y)

        # dtype, scalar
        for x, y in [
            (1j, xp.int64),
            (True, xp.float64),
        ]:
            assert xp.result_type(x, y) == torch.result_type(x, xp.empty([], dtype=y))

        # dtype, dtype
        for x, y in [
            (xp.bool, xp.complex64)
        ]:
            xt, yt = xp.empty([], dtype=x), xp.empty([], dtype=y)
            assert xp.result_type(x, y) == torch.result_type(xt, yt)

    def test_multi_arg(self):
        torch.set_default_dtype(torch.float32)

        args = [1., 5, 3, torch.asarray([3], dtype=torch.float16), 5, 6, 1.]
        assert xp.result_type(*args) == torch.float16

        args = [1, 2, 3j, xp.arange(3, dtype=xp.float32), 4, 5, 6]
        assert xp.result_type(*args) == xp.complex64

        args = [1, 2, 3j, xp.float64, 4, 5, 6]
        assert xp.result_type(*args) == xp.complex128

        args = [1, 2, 3j, xp.float64, 4, xp.asarray(3, dtype=xp.int16), 5, 6, False]
        assert xp.result_type(*args) == xp.complex128

        i64 = xp.ones(1, dtype=xp.int64)
        f16 = xp.ones(1, dtype=xp.float16)
        for i in itertools.permutations([i64, f16, 1.0, 1.0]):
            assert xp.result_type(*i) == xp.float16, f"{i}"

        with pytest.raises(ValueError):
            xp.result_type(1, 2, 3, 4)

    @pytest.mark.parametrize("default_dt", ['float32', 'float64'])
    @pytest.mark.parametrize(
        "dtype_a",
        (xp.int32, xp.int64, xp.float32, xp.float64, xp.complex64, xp.complex128),
    )
    @pytest.mark.parametrize(
        "dtype_b",
        (xp.int32, xp.int64, xp.float32, xp.float64, xp.complex64, xp.complex128),
    )
    def test_gh_273(self, default_dt, dtype_a, dtype_b):
        # Regression test for https://github.com/data-apis/array-api-compat/issues/273
        try:
            prev_default = torch.get_default_dtype()
            default_dtype = getattr(torch, default_dt)
            torch.set_default_dtype(default_dtype)

            a = xp.asarray([2, 1], dtype=dtype_a)
            b = xp.asarray([1, -1], dtype=dtype_b)
            dtype_1 = xp.result_type(a, b, 1.0)
            dtype_2 = xp.result_type(b, a, 1.0)
            assert dtype_1 == dtype_2
        finally:
            torch.set_default_dtype(prev_default)
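
# The try/finally dance in test_gh_273 reads more naturally as a context
# manager; a hedged sketch of the same guard (an addition here, unused above):
import contextlib

@contextlib.contextmanager
def _default_dtype(dtype):
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(previous)

# Example: with _default_dtype(torch.float64): ...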

array-api-compat-1.11.2/tests/test_vendoring.py
import pytest


def test_vendoring_numpy():
    from vendor_test import uses_numpy
    uses_numpy._test_numpy()


def test_vendoring_cupy():
    pytest.importorskip("cupy")
    from vendor_test import uses_cupy
    uses_cupy._test_cupy()


def test_vendoring_torch():
    from vendor_test import uses_torch
    uses_torch._test_torch()


def test_vendoring_dask():
    from vendor_test import uses_dask
    uses_dask._test_dask()

array-api-compat-1.11.2/torch-skips.txt
# These tests cause a core dump on CI, so we have to skip them entirely
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)]

array-api-compat-1.11.2/torch-xfails.txt
# Note: see array_api_compat/torch/_aliases.py for links to corresponding
# pytorch issues

# We cannot wrap the array object
# Indexing does not support negative step
array_api_tests/test_array_object.py::test_getitem
array_api_tests/test_array_object.py::test_setitem
# Masking doesn't support 0 dimensions in the mask
array_api_tests/test_array_object.py::test_getitem_masking
# Overflow error from large inputs
array_api_tests/test_creation_functions.py::test_arange

# We cannot wrap the tensor object
array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
array_api_tests/test_has_names.py::test_has_names[array_method-to_device]

# We cannot wrap the tensor object
array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__and__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__lshift__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[__or__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__rshift__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__irshift__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__xor__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_equal[__eq__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[__ne__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[__le__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__pow__(x1, x2)]
# This test is skipped instead of xfailed because it causes core dumps on CI
# array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x1, x2)]
# inverse trig functions are too inaccurate on CPU
array_api_tests/test_operators_and_elementwise_functions.py::test_acos
array_api_tests/test_operators_and_elementwise_functions.py::test_atan
array_api_tests/test_operators_and_elementwise_functions.py::test_asin
# Torch bug for remainder in some cases with large values
array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)]
# unique_all cannot be implemented because torch's unique does not support
# returning indices
array_api_tests/test_set_functions.py::test_unique_all
# unique_inverse incorrectly counts nan values
# (https://github.com/pytorch/pytorch/issues/94106)
array_api_tests/test_set_functions.py::test_unique_inverse
# We cannot add attributes to the tensor object
array_api_tests/test_signatures.py::test_array_method_signature[__array_namespace__]
array_api_tests/test_signatures.py::test_array_method_signature[to_device]

# We do not attempt to work around special-case differences (most are on
# tensor methods which we couldn't fix anyway).
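# (Each entry below pins a single IEEE 754 special case of floor_divide /
# remainder where torch's result differs from the one the spec mandates.)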
array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> +0]
array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is +infinity) -> +0]
array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i < 0 and x2_i is +0) -> -infinity]
array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i < 0 and x2_i is -0) -> +infinity]
array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i > 0 and x2_i is +0) -> +infinity]
array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i > 0 and x2_i is -0) -> -infinity]
array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +0 and x2_i < 0) -> -0]
array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +0 and x2_i > 0) -> +0]
array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -0 and x2_i < 0) -> +0]
array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -0 and x2_i > 0) -> -0]
array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> x2_i]
array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> x1_i]
array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i > 0 and x2_i is +infinity) -> x1_i]
array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> x2_i]
array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is +0 and x2_i < 0) -> -0]
array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is +0 and x2_i > 0) -> +0]
array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is -0 and x2_i < 0) -> -0]
array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is -0 and x2_i > 0) -> +0]
array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is +0 and x2_i < 0) -> -0]
array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i > 0) -> +0]
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is +0 and x2_i < 0) -> -0]
array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is -0 and x2_i > 0) -> +0]

# Float correction is not supported by pytorch
# (https://github.com/data-apis/array-api-tests/issues/168)
array_api_tests/test_statistical_functions.py::test_std
array_api_tests/test_statistical_functions.py::test_var

# These functions do not yet support complex numbers
array_api_tests/test_operators_and_elementwise_functions.py::test_round
array_api_tests/test_set_functions.py::test_unique_counts
array_api_tests/test_set_functions.py::test_unique_values

# 2023.12 support
array_api_tests/test_has_names.py::test_has_names[manipulation-repeat]
array_api_tests/test_manipulation_functions.py::test_repeat
array_api_tests/test_signatures.py::test_func_signature[repeat]
# Argument 'device' missing from signature
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
# Argument 'max_version' missing from signature
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]

# 2024.12 support
array_api_tests/test_signatures.py::test_func_signature[bitwise_and]
array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift]
array_api_tests/test_signatures.py::test_func_signature[bitwise_or]
array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift]
array_api_tests/test_signatures.py::test_func_signature[bitwise_xor]
array_api_tests/test_signatures.py::test_array_method_signature[__and__]
array_api_tests/test_signatures.py::test_array_method_signature[__lshift__]
array_api_tests/test_signatures.py::test_array_method_signature[__or__]
array_api_tests/test_signatures.py::test_array_method_signature[__rshift__]
array_api_tests/test_signatures.py::test_array_method_signature[__xor__]

# 2024.12 support: binary functions reject python scalar arguments
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[atan2]
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[copysign]
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[divide]
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[hypot]
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp]
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum]
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[minimum]
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[equal]
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[not_equal]
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[less]
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[less_equal]
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[greater]
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[greater_equal]
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_bool[logical_and]
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_bool[logical_or]
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_bool[logical_xor]

array-api-compat-1.11.2/vendor_test/
array-api-compat-1.11.2/vendor_test/__init__.py
array-api-compat-1.11.2/vendor_test/uses_cupy.py
# Basic test that vendoring works

from .vendored._compat import (
    cupy as cp_compat,
    is_cupy_array,
    is_cupy_namespace,
)

import cupy as cp


def _test_cupy():
    a = cp_compat.asarray([1., 2., 3.])
    b = cp_compat.arange(3, dtype=cp_compat.float32)

    # cp.pow does not exist. Update this to use something else if it is added
    res = cp_compat.pow(a, b)

    assert res.dtype == cp_compat.float64 == cp.float64
    assert isinstance(a, cp.ndarray)
    assert isinstance(b, cp.ndarray)
    assert isinstance(res, cp.ndarray)

    cp.testing.assert_allclose(res, [1., 2., 9.])

    assert is_cupy_array(res)
    assert is_cupy_namespace(cp) and is_cupy_namespace(cp_compat)

array-api-compat-1.11.2/vendor_test/uses_dask.py
# Basic test that vendoring works

from .vendored._compat.dask import array as dask_compat
from .vendored._compat import is_dask_array, is_dask_namespace

import dask.array as da
import numpy as np


def _test_dask():
    a = dask_compat.asarray([1., 2., 3.])
    b = dask_compat.arange(3, dtype=dask_compat.float32)

    # np.pow does not exist. Update this to use something else if it is added
    res = dask_compat.pow(a, b)

    assert res.dtype == dask_compat.float64 == np.float64
    assert isinstance(a, da.Array)
    assert isinstance(b, da.Array)
    assert isinstance(res, da.Array)

    np.testing.assert_allclose(res, [1., 2., 9.])

    assert is_dask_array(res)
    assert is_dask_namespace(da) and is_dask_namespace(dask_compat)
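
# How this vendoring resolves: the ".vendored._compat" package is the
# "_compat" symlink at vendor_test/vendored/_compat -> ../../array_api_compat/
# (see the end of this archive). In other words, vendoring amounts to copying
# or linking the array_api_compat package into your tree and importing it
# relatively.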

array-api-compat-1.11.2/vendor_test/uses_numpy.py
# Basic test that vendoring works

from .vendored._compat import (
    is_numpy_array,
    is_numpy_namespace,
    numpy as np_compat,
)

import numpy as np


def _test_numpy():
    a = np_compat.asarray([1., 2., 3.])
    b = np_compat.arange(3, dtype=np_compat.float32)

    # np.pow does not exist. Update this to use something else if it is added
    res = np_compat.pow(a, b)

    assert res.dtype == np_compat.float64 == np.float64
    assert isinstance(a, np.ndarray)
    assert isinstance(b, np.ndarray)
    assert isinstance(res, np.ndarray)

    np.testing.assert_allclose(res, [1., 2., 9.])

    assert is_numpy_array(res)
    assert is_numpy_namespace(np) and is_numpy_namespace(np_compat)

array-api-compat-1.11.2/vendor_test/uses_torch.py
# Basic test that vendoring works

from .vendored._compat import (
    is_torch_array,
    is_torch_namespace,
    torch as torch_compat,
)

import torch


def _test_torch():
    a = torch_compat.asarray([1., 2., 3.])
    b = torch_compat.arange(3, dtype=torch_compat.float64)
    assert a.dtype == torch_compat.float32 == torch.float32
    assert b.dtype == torch_compat.float64 == torch.float64

    # torch.expand_dims does not exist. Update this to use something else if it is added
    res = torch_compat.expand_dims(a, axis=0)
    assert res.dtype == torch_compat.float32 == torch.float32
    assert res.shape == (1, 3)
    assert isinstance(res.shape, torch.Size)
    assert isinstance(a, torch.Tensor)
    assert isinstance(b, torch.Tensor)
    assert isinstance(res, torch.Tensor)

    torch.testing.assert_close(res, torch.as_tensor([[1., 2., 3.]]))

    assert is_torch_array(res)
    assert is_torch_namespace(torch) and is_torch_namespace(torch_compat)

array-api-compat-1.11.2/vendor_test/vendored/
array-api-compat-1.11.2/vendor_test/vendored/__init__.py
array-api-compat-1.11.2/vendor_test/vendored/_compat -> ../../array_api_compat/