pax_global_header00006660000000000000000000000064151067166560014527gustar00rootroot0000000000000052 comment=9fd1a480f1cdb23b3d28dfea5eadf3d84b6dfc62 jax-ml-ml_dtypes-882eb0f/000077500000000000000000000000001510671665600153355ustar00rootroot00000000000000jax-ml-ml_dtypes-882eb0f/.clang-format000066400000000000000000000001661510671665600177130ustar00rootroot00000000000000BasedOnStyle: Google Language: Cpp PointerBindsToType: true SortIncludes: Never AlignTrailingComments: Kind: Always jax-ml-ml_dtypes-882eb0f/.github/000077500000000000000000000000001510671665600166755ustar00rootroot00000000000000jax-ml-ml_dtypes-882eb0f/.github/dependabot.yml000066400000000000000000000003561510671665600215310ustar00rootroot00000000000000version: 2 updates: - package-ecosystem: pip directory: / schedule: interval: weekly ignore: - dependency-name: "numpy" - package-ecosystem: github-actions directory: / schedule: interval: weekly jax-ml-ml_dtypes-882eb0f/.github/workflows/000077500000000000000000000000001510671665600207325ustar00rootroot00000000000000jax-ml-ml_dtypes-882eb0f/.github/workflows/test.yml000066400000000000000000000103311510671665600224320ustar00rootroot00000000000000name: Test on: # Trigger the workflow on push or pull request, but only on main branch push: branches: - main pull_request: branches: - main permissions: {} concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} cancel-in-progress: true jobs: lint: runs-on: ubuntu-latest steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: persist-credentials: false - name: Set up Python 3.12 uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 with: python-version: 3.12 - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 build: name: ${{ matrix.os }} Python ${{ matrix.python-version }} freethreaded ${{ matrix.freethreaded }} runs-on: ${{ matrix.os }} strategy: matrix: os: ["ubuntu-latest"] python-version: ["3.9", 
"3.10", "3.11", "3.12", "3.13", "3.14"] freethreaded: [false, true] exclude: - python-version: 3.9 freethreaded: true - python-version: 3.10 freethreaded: true - python-version: 3.11 freethreaded: true - python-version: 3.12 freethreaded: true include: - os: macos-14 python-version: "3.12" freethreaded: false - os: macos-14 python-version: "3.14" freethreaded: true - os: windows-2022 python-version: "3.12" freethreaded: false - os: windows-2022 python-version: "3.14" freethreaded: true - os: windows-11-arm python-version: "3.12" freethreaded: false - os: windows-11-arm python-version: "3.14" freethreaded: true steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: true persist-credentials: false - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 with: python-version: ${{ matrix.python-version }} freethreaded: ${{ matrix.freethreaded }} - name: Install dependencies run: | python -m pip install --upgrade pip pip install .[dev] - name: Run tests run: | pytest -n auto build-nightly: name: Python 3.14 with nightly numpy runs-on: ubuntu-latest steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: true persist-credentials: false - name: Set up Python 3.14 uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 with: python-version: "3.14" - name: Install dependencies run: | python -m pip install --upgrade pip python -m pip install setuptools wheel python -m pip install -U --pre numpy \ -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple python -c "import numpy; print(f'{numpy.__version__=}')" - name: Build ml_dtypes run: | python -m pip install .[dev] --no-build-isolation - name: Run tests run: | pytest -n auto build-oldest-numpy: name: Python 3.9 with oldest supported numpy runs-on: ubuntu-latest steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 
# v4.2.2 with: submodules: true persist-credentials: false - name: Set up Python 3.9 uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 with: python-version: "3.9" - name: Install dependencies run: | python -m pip install --upgrade pip python -m pip install --upgrade setuptools wheel - name: Build ml_dtypes run: | python -m pip install .[dev] python -m pip install numpy==1.21.0 # keep in sync with oldest numpy version in pyproject.toml - name: Run tests run: | pytest -n auto jax-ml-ml_dtypes-882eb0f/.github/workflows/wheels.yml000066400000000000000000000073511510671665600227520ustar00rootroot00000000000000name: Build on: workflow_dispatch: {} # allows triggering this workflow manually push: branches: # trigger on commits to main branch - main pull_request: # trigger on pull requests affecting relevant files branches: - main paths: - '**workflows/wheels.yml' - 'pyproject.toml' release: # trigger on published release types: - published permissions: {} concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} cancel-in-progress: true jobs: build_wheels: name: Build wheels on ${{ matrix.os }} runs-on: ${{ matrix.os }} strategy: matrix: os: [ubuntu-22.04, ubuntu-22.04-arm, macos-14, windows-2022] cibw_build: ["cp39-* cp310-* cp311-* cp312-* cp313-* cp313t-* cp314-* cp314t-*"] include: - os: windows-11-arm cibw_build: "cp311-* cp312-* cp313-* cp313t-* cp314-* cp314t-*" steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: true persist-credentials: false # Used to host cibuildwheel - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 with: python-version: "3.12" - name: Install cibuildwheel run: python -m pip install cibuildwheel==3.1.2 - name: Build wheels run: python -m cibuildwheel --output-dir wheelhouse env: CIBW_ARCHS_LINUX: auto CIBW_ARCHS_MACOS: universal2 CIBW_BUILD: ${{ matrix.cibw_build }} CIBW_ENABLE: cpython-freethreading 
CIBW_PRERELEASE_PYTHONS: True CIBW_SKIP: "*musllinux* *i686* *win32*" CIBW_TEST_REQUIRES: absl-py pytest pytest-xdist CIBW_TEST_COMMAND: pytest -n auto {project} CIBW_BUILD_VERBOSITY: 1 - uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0 with: name: cibw-wheels-${{ matrix.os }}-${{ strategy.job-index }} path: ./wheelhouse/*.whl build_sdist: name: Build source distribution runs-on: ubuntu-latest steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: true persist-credentials: false - name: Build sdist run: pipx run build --sdist - uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0 with: name: cibw-sdist path: dist/*.tar.gz download_and_list_artifacts: # Helps debug issues like https://github.com/jax-ml/ml_dtypes/issues/196 name: Download and list artifacts needs: [build_sdist, build_wheels] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@018cc2cf5baa6db3ef3c5f8a56943fffe632ef53 # v6.0.0 with: # unpacks all CIBW artifacts into dist/ pattern: cibw-* path: dist merge-multiple: true - name: List files run: ls -l dist/ upload_pypi: name: Release & Upload to PyPI needs: [build_sdist, build_wheels] runs-on: ubuntu-latest environment: release permissions: id-token: write # Only publish release to PyPI when a github release is created. if: github.event_name == 'release' && github.event.action == 'published' steps: - uses: actions/download-artifact@018cc2cf5baa6db3ef3c5f8a56943fffe632ef53 # v6.0.0 with: # unpacks all CIBW artifacts into dist/ pattern: cibw-* path: dist merge-multiple: true - name: List files run: ls -l dist/ - name: Publish package distributions to PyPI uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e # v1.13.0 jax-ml-ml_dtypes-882eb0f/.gitignore000066400000000000000000000004241510671665600173250ustar00rootroot00000000000000# Compiled python modules. 
*.pyc *.so # Byte-compiled __pycache__/ .cache/ # Poetry, setuptools, PyPI distribution artifacts. /*.egg-info .eggs/ build/ dist/ poetry.lock # Tests .pytest_cache/ # Type checking .pytype/ # Other *.DS_Store # PyCharm .idea # pipenv Pipfile jax-ml-ml_dtypes-882eb0f/.gitmodules000066400000000000000000000002661510671665600175160ustar00rootroot00000000000000[submodule "eigen"] path = eigen url = https://gitlab.com/libeigen/eigen.git [submodule "third_party/eigen"] path = third_party/eigen url = https://gitlab.com/libeigen/eigen.git jax-ml-ml_dtypes-882eb0f/.pre-commit-config.yaml000066400000000000000000000017151510671665600216220ustar00rootroot00000000000000repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: 3e8a8703264a2f4a69428a0aa4dcb512790b2c8c # v6.0.0 hooks: - id: check-ast - id: check-merge-conflict - id: check-toml - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace - id: debug-statements - repo: https://github.com/google/pyink rev: 1de8968a0d9a9ad4b8f7d9378602d14c5939602d # v24.10.1 hooks: - id: pyink language_version: python3.12 args: [ "--line-length=80", "--preview", "--pyink-indentation=2", "--pyink-use-majority-quotes" ] - repo: https://github.com/astral-sh/ruff-pre-commit rev: 9c89adb347f6b973f4905a4be0051eb2ecf85dea # v0.13.3 hooks: - id: ruff - repo: https://github.com/pre-commit/mirrors-clang-format rev: 719856d56a62953b8d2839fb9e851f25c3cfeef8 # v21.1.2 hooks: - id: clang-format files: ml_dtypes/ jax-ml-ml_dtypes-882eb0f/.pylintrc000066400000000000000000000346771510671665600172230ustar00rootroot00000000000000# This Pylint rcfile contains a best-effort configuration to uphold the # best-practices and style described in the Google Python style guide: # https://google.github.io/styleguide/pyguide.html # # Its canonical open-source location is: # https://google.github.io/styleguide/pylintrc [MASTER] # Add files or directories to the ignore list. They should be base names, not # paths.
ignore=third_party # Add files or directories matching the regex patterns to the ignore list. The # regex matches against base names, not paths. ignore-patterns= # Pickle collected data for later comparisons. persistent=no # List of plugins (as comma separated values of Python module names) to load, # usually to register additional checkers. load-plugins= # Use multiple processes to speed up Pylint. jobs=4 # Allow loading of arbitrary C extensions. Extensions are imported into the # active Python interpreter and may run arbitrary code. unsafe-load-any-extension=no # A comma-separated list of package or module names from where C extensions may # be loaded. Extensions are loaded into the active Python interpreter and may # run arbitrary code. extension-pkg-allow-list= [MESSAGES CONTROL] # Only show warnings with the listed confidence levels. Leave empty to show # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED confidence= # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifiers separated by comma (,) or put this option # multiple times (only on the command line, not in the configuration file where # it should appear only once). See also the "--disable" option for examples. #enable= # Disable the message, report, category or checker with the given id(s). You # can either give multiple identifiers separated by comma (,) or put this # option multiple times (only on the command line, not in the configuration # file where it should appear only once). You can also use "--disable=all" to # disable everything first and then re-enable specific checks. For example, if # you want to run only the similarities checker, you can use "--disable=all # --enable=similarities".
If you want to run only the classes checker, but have # no Warning level messages displayed, use"--disable=all --enable=classes # --disable=W" disable=abstract-method, apply-builtin, arguments-differ, attribute-defined-outside-init, backtick, bad-option-value, basestring-builtin, buffer-builtin, c-extension-no-member, consider-using-enumerate, cmp-builtin, cmp-method, coerce-builtin, coerce-method, delslice-method, div-method, duplicate-code, eq-without-hash, execfile-builtin, file-builtin, filter-builtin-not-iterating, fixme, getslice-method, global-statement, hex-method, idiv-method, implicit-str-concat-in-sequence, import-error, import-self, import-star-module-level, inconsistent-return-statements, input-builtin, intern-builtin, invalid-str-codec, locally-disabled, long-builtin, long-suffix, map-builtin-not-iterating, misplaced-comparison-constant, missing-function-docstring, metaclass-assignment, next-method-called, next-method-defined, no-absolute-import, no-else-break, no-else-continue, no-else-raise, no-else-return, no-init, # added no-member, no-name-in-module, no-self-use, nonzero-method, oct-method, old-division, old-ne-operator, old-octal-literal, old-raise-syntax, parameter-unpacking, print-statement, raising-string, range-builtin-not-iterating, raw_input-builtin, rdiv-method, reduce-builtin, relative-import, reload-builtin, round-builtin, setslice-method, signature-differs, standarderror-builtin, suppressed-message, sys-max-int, too-few-public-methods, too-many-ancestors, too-many-arguments, too-many-boolean-expressions, too-many-branches, too-many-instance-attributes, too-many-locals, too-many-nested-blocks, too-many-public-methods, too-many-return-statements, too-many-statements, trailing-newlines, unichr-builtin, unicode-builtin, unnecessary-pass, unpacking-in-except, useless-else-on-loop, useless-object-inheritance, useless-suppression, using-cmp-argument, wrong-import-order, xrange-builtin, zip-builtin-not-iterating, [REPORTS] # Set the output 
format. Available formats are text, parseable, colorized, msvs # (visual studio) and html. You can also give a reporter class, eg # mypackage.mymodule.MyReporterClass. output-format=text # Put messages in a separate file for each module / package specified on the # command line instead of printing them on stdout. Reports (if any) will be # written in a file name "pylint_global.[txt|html]". This option is deprecated # and it will be removed in Pylint 2.0. files-output=no # Tells whether to display a full report or only the messages reports=no # Python expression which should return a note less than 10 (10 is the highest # note). You have access to the variables errors warning, statement which # respectively contain the number of errors / warnings messages and the total # number of statements analyzed. This is used by the global evaluation report # (RP0004). evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) # Template used to display messages. This is a python new-style format string # used to format the message information. See doc for all details #msg-template= [BASIC] # Good variable names which should always be accepted, separated by a comma good-names=main,_ # Bad variable names which should always be refused, separated by a comma bad-names= # Colon-delimited sets of names that determine each other's naming style when # the name regexes allow several styles. name-group= # Include a hint for the correct naming format with invalid-name include-naming-hint=no # List of decorators that produce properties, such as abc.abstractproperty. Add # to this list to register other decorators that produce valid properties. 
property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl # Regular expression matching correct function names function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$ # Regular expression matching correct variable names variable-rgx=^[a-z][a-z0-9_]*$ # Regular expression matching correct constant names const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ # Regular expression matching correct attribute names attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ # Regular expression matching correct argument names argument-rgx=^[a-z][a-z0-9_]*$ # Regular expression matching correct class attribute names class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ # Regular expression matching correct inline iteration names inlinevar-rgx=^[a-z][a-z0-9_]*$ # Regular expression matching correct class names class-rgx=^_?[A-Z][a-zA-Z0-9]*$ # Regular expression matching correct module names module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$ # Regular expression matching correct method names method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$ # Regular expression which should only match function or class names that do # not require a docstring. no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$ # Minimum line length for functions/classes that require docstrings, shorter # ones are exempt. docstring-min-length=10 [TYPECHECK] # List of decorators that produce context managers, such as # contextlib.contextmanager. Add to this list to register other decorators that # produce valid context managers.
contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager # Tells whether missing members accessed in mixin class should be ignored. A # mixin class is detected if its name ends with "mixin" (case insensitive). ignore-mixin-members=yes # List of module names for which member attributes should not be checked # (useful for modules/projects where namespaces are manipulated during runtime # and thus existing member attributes cannot be deduced by static analysis). It # supports qualified module names, as well as Unix pattern matching. ignored-modules= # List of class names for which member attributes should not be checked (useful # for classes with dynamically set attributes). This supports the use of # qualified names. ignored-classes=optparse.Values,thread._local,_thread._local # List of members which are set dynamically and missed by pylint inference # system, and so shouldn't trigger E1101 when accessed. Python regular # expressions are accepted. generated-members= [FORMAT] # Maximum number of characters on a single line. max-line-length=80 # TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt # lines made too long by directives to pytype. # Regexp for a line that is allowed to be longer than the limit. ignore-long-lines=(?x)( ^\s*(\#\ )?<?https?://\S+>?$| ^\s*(from\s+\S+\s+)?import\s+.+$) # Allow the body of an if to be on the same line as the test if there is no # else. single-line-if-stmt=yes # List of optional constructs for which whitespace checking is disabled. `dict- # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. # `trailing-comma` allows a space between comma and closing bracket: (a, ). # `empty-line` allows space-only lines. no-space-check= # Maximum number of lines in a module max-module-lines=99999 # String used as indentation unit. The internal Google style guide mandates 2 # spaces. Google's externally-published style guide says 4, consistent with # PEP 8.
Here, we use 2 spaces, for conformity with many open-sourced Google # projects (like TensorFlow). indent-string=' ' # Number of spaces of indent required inside a hanging or continued line. indent-after-paren=4 # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. expected-line-ending-format= [MISCELLANEOUS] # List of note tags to take in consideration, separated by a comma. notes=TODO [STRING] # This flag controls whether inconsistent-quotes generates a warning when the # character used as a quote delimiter is used inconsistently within a module. check-quote-consistency=yes [VARIABLES] # Tells whether we should check for unused import in __init__ files. init-import=no # A regular expression matching the name of dummy variables (i.e. expectedly # not used). dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) # List of additional names supposed to be defined in builtins. Remember that # you should avoid to define new builtins when possible. additional-builtins= # List of strings which can identify a callback function by name. A callback # name must start or end with one of those strings. callbacks=cb_,_cb # List of qualified module names which can have objects that can redefine # builtins. redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools [LOGGING] # Logging modules to check that the string format arguments are in logging # function parameter format logging-modules=logging,absl.logging,tensorflow.io.logging [SIMILARITIES] # Minimum lines number of a similarity. min-similarity-lines=4 # Ignore comments when computing similarities. ignore-comments=yes # Ignore docstrings when computing similarities. ignore-docstrings=yes # Ignore imports when computing similarities. ignore-imports=no [SPELLING] # Spelling dictionary name. Available dictionaries: none. To make it working # install python-enchant package. spelling-dict= # List of comma separated words that should not be checked. 
spelling-ignore-words= # A path to a file that contains private dictionary; one word per line. spelling-private-dict-file= # Tells whether to store unknown words to indicated private dictionary in # --spelling-private-dict-file option instead of raising a message. spelling-store-unknown-words=no [IMPORTS] # Deprecated modules which should not be used, separated by a comma deprecated-modules=regsub, TERMIOS, Bastion, rexec, sets # Create a graph of every (i.e. internal and external) dependencies in the # given file (report RP0402 must not be disabled) import-graph= # Create a graph of external dependencies in the given file (report RP0402 must # not be disabled) ext-import-graph= # Create a graph of internal dependencies in the given file (report RP0402 must # not be disabled) int-import-graph= # Force import order to recognize a module as part of the standard # compatibility libraries. known-standard-library= # Force import order to recognize a module as part of a third party library. known-third-party=enchant, absl # Analyse import fallback blocks. This can be used to support both Python 2 and # 3 compatible code, which means that the block might have code that exists # only in one or another interpreter, leading to false positives when analysed. analyse-fallback-blocks=no [CLASSES] # List of method names used to declare (i.e. assign) instance attributes. defining-attr-methods=__init__, __new__, setUp # List of member names, which should be excluded from the protected access # warning. exclude-protected=_asdict, _fields, _replace, _source, _make # List of valid names for the first argument in a class method. valid-classmethod-first-arg=cls, class_ # List of valid names for the first argument in a metaclass class method. valid-metaclass-classmethod-first-arg=mcs [EXCEPTIONS] # Exceptions that will emit a warning when being caught. 
Defaults to # "Exception" overgeneral-exceptions=StandardError, Exception, BaseException jax-ml-ml_dtypes-882eb0f/.vscode/000077500000000000000000000000001510671665600166765ustar00rootroot00000000000000jax-ml-ml_dtypes-882eb0f/.vscode/settings.json000066400000000000000000000015001510671665600214250ustar00rootroot00000000000000{ "files.insertFinalNewline": true, "files.trimFinalNewlines": true, "files.trimTrailingWhitespace": true, "files.associations": { ".pylintrc": "ini" }, "python.testing.unittestEnabled": false, "python.testing.nosetestsEnabled": false, "python.testing.pytestEnabled": true, "python.linting.pylintUseMinimalCheckers": false, "[python]": { "editor.rulers": [ 80 ], "editor.tabSize": 2, "editor.formatOnSave": true, "editor.detectIndentation": false }, "python.formatting.provider": "black", "python.formatting.blackPath": "pyink", "files.watcherExclude": { "**/.git/**": true }, "files.exclude": { "**/__pycache__": true, "**/.pytest_cache": true, "**/*.egg-info": true } } jax-ml-ml_dtypes-882eb0f/AUTHORS000066400000000000000000000004551510671665600164110ustar00rootroot00000000000000# This is the list of significant contributors to ml_dtypes # # This does not necessarily list everyone who has contributed code, # especially since many employees of one corporation may be contributing. # To see the full list of contributors, see the revision history in # source control. Google LLC jax-ml-ml_dtypes-882eb0f/CHANGELOG.md000066400000000000000000000112641510671665600171520ustar00rootroot00000000000000# Changelog ## [Unreleased] ## [0.5.4] - 2025-11-17 * We now register casts from int2 and int4 to all of the custom float types, except `float6_e2m3fn` and `float8_e8m0fnu`. * Custom floats may now be constructed from Python integers ([#317](https://github.com/jax-ml/ml_dtypes/issues/317)) * Fixed bug in byte-swap operation for custom floats ([#311](https://github.com/jax-ml/ml_dtypes/pull/311)) * Wheels now support Python 3.14 free threading on Windows. 
## [0.5.3] - 2025-07-29 * NPY_NEEDS_PYAPI was removed from the dtype flags. This should improve the speed of array operations, but it does mean that values pickled using previous versions of ml_dtypes are incompatible with the current release and should be regenerated with the current release. * Wheels now support Python 3.14. * Wheels now support Windows 11 ARM. ## [0.5.2] - 2025-01-31 * Dropped support for Power wheels again. These turned out to cause problems in our release process. We will consider re-adding these if NumPy ships Power wheels. * Fixed GCC compilation issues related to ambiguous casts. ## [0.5.1] - 2025-01-06 * Fixed sign bit handling for float4 and float6 types. * Wheels now support Python 3.13 free-threading. * Wheels now support the Power architecture. ## [0.5.0] - 2024-09-13 * Added new 8-bit float types following IEEE 754 convention: `ml_dtypes.float8_e4m3`, `ml_dtypes.float8_e3m4` * Added the 8-bit floating point type `ml_dtypes.float8_e8m0fnu`, which is the OpenCompute MX scale format. * Added new 4-bit and 6-bit float types: `ml_dtypes.float4_e2m1fn`, `ml_dtypes.float6_e2m3fn` and `ml_dtypes.float6_e3m2fn`. * Fix outputs of float `divmod` and `floor_divide` when denominator is zero. ## [0.4.1] - 2024-09-13 * Updates build requirements to use NumPy 2.0 release ## [0.4.0] - 2024-04-01 * Updates `ml_dtypes` for compatibility with future NumPy 2.0 release. * Wheels are built against NumPy 2.0.0rc1. ## [0.4.0b1] - 2024-03-12 * Updates `ml_dtypes` for compatibility with future NumPy 2.0 release. * Wheels for the release candidate are built against NumPy 2.0.0b1. ## [0.3.2] - 2024-01-03 * Fixed spurious invalid value warnings when casting between floating point types on Mac ARM.
* Remove `pybind11` build requirement * Update C++ sources for compatibility with NumPy 2.0 ## [0.3.1] - 2023-09-22 * Added support for int4 casting to wider integers such as int8 * Added support to cast np.float32 and np.float64 into int4 ## [0.3.0] - 2023-09-19 * Dropped support for Python 3.8, following [NEP 29]. * Added support for Python 3.12. * Removed deprecated name `ml_dtypes.float8_e4m3b11`; use `ml_dtypes.float8_e4m3b11fnuz` instead. ## [0.2.0] - 2023-06-06 * New features: * added new 4-bit integer types: `ml_dtypes.int4` and `ml_dtypes.uint4` * Deprecations: * `ml_dtypes.float8_e4m3b11` has been renamed to `ml_dtypes.float8_e4m3b11fnuz` for more consistency with other dtype names. The former name will still be available until version 0.3.0, but will raise a deprecation warning. ## [0.1.0] - 2023-04-11 * Initial release [Unreleased]: https://github.com/jax-ml/ml_dtypes/compare/v0.5.4...HEAD [0.5.4]: https://github.com/jax-ml/ml_dtypes/compare/v0.5.3...v0.5.4 [0.5.3]: https://github.com/jax-ml/ml_dtypes/compare/v0.5.2...v0.5.3 [0.5.2]: https://github.com/jax-ml/ml_dtypes/compare/v0.5.1...v0.5.2 [0.5.1]: https://github.com/jax-ml/ml_dtypes/compare/v0.5.0...v0.5.1 [0.5.0]: https://github.com/jax-ml/ml_dtypes/compare/v0.4.0...v0.5.0 [0.4.1]: https://github.com/jax-ml/ml_dtypes/compare/v0.4.0...v0.4.1 [0.4.0]: https://github.com/jax-ml/ml_dtypes/compare/v0.4.0b1...v0.4.0 [0.4.0b1]: https://github.com/jax-ml/ml_dtypes/compare/v0.3.2...v0.4.0b1 [0.3.2]: https://github.com/jax-ml/ml_dtypes/compare/v0.3.1...v0.3.2 [0.3.1]: https://github.com/jax-ml/ml_dtypes/compare/v0.3.0...v0.3.1 [0.3.0]: https://github.com/jax-ml/ml_dtypes/compare/v0.2.0...v0.3.0 [0.2.0]: https://github.com/jax-ml/ml_dtypes/compare/v0.1.0...v0.2.0 [0.1.0]: https://github.com/jax-ml/ml_dtypes/releases/tag/v0.1.0 [NEP 29]: https://numpy.org/neps/nep-0029-deprecation_policy.html
jax-ml-ml_dtypes-882eb0f/CONTRIBUTING.md000066400000000000000000000021171510671665600175670ustar00rootroot00000000000000# How to Contribute We'd love to accept your patches and contributions to this project. There are just a few small guidelines you need to follow. ## Contributor License Agreement Contributions to this project must be accompanied by a Contributor License Agreement (CLA). You (or your employer) retain the copyright to your contribution; this simply gives us permission to use and redistribute your contributions as part of the project. Head over to <https://cla.developers.google.com/> to see your current agreements on file or to sign a new one. You generally only need to submit a CLA once, so if you've already submitted one (even if it was for a different project), you probably don't need to do it again. ## Code Reviews All submissions, including submissions by project members, require review. We use GitHub pull requests for this purpose. Consult [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more information on using pull requests. ## Community Guidelines This project follows [Google's Open Source Community Guidelines](https://opensource.google/conduct/). jax-ml-ml_dtypes-882eb0f/LICENSE000066400000000000000000000261361510671665600163520ustar00rootroot00000000000000 Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity.
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. jax-ml-ml_dtypes-882eb0f/LICENSE.eigen000066400000000000000000000405261510671665600174370ustar00rootroot00000000000000Mozilla Public License Version 2.0 ================================== 1. Definitions -------------- 1.1. "Contributor" means each individual or legal entity that creates, contributes to the creation of, or owns Covered Software. 1.2. "Contributor Version" means the combination of the Contributions of others (if any) used by a Contributor and that particular Contributor's Contribution. 1.3. "Contribution" means Covered Software of a particular Contributor. 1.4. "Covered Software" means Source Code Form to which the initial Contributor has attached the notice in Exhibit A, the Executable Form of such Source Code Form, and Modifications of such Source Code Form, in each case including portions thereof. 1.5. "Incompatible With Secondary Licenses" means (a) that the initial Contributor has attached the notice described in Exhibit B to the Covered Software; or (b) that the Covered Software was made available under the terms of version 1.1 or earlier of the License, but not also under the terms of a Secondary License. 1.6. "Executable Form" means any form of the work other than Source Code Form. 1.7. "Larger Work" means a work that combines Covered Software with other material, in a separate file or files, that is not Covered Software. 1.8. "License" means this document. 1.9. 
"Licensable" means having the right to grant, to the maximum extent possible, whether at the time of the initial grant or subsequently, any and all of the rights conveyed by this License. 1.10. "Modifications" means any of the following: (a) any file in Source Code Form that results from an addition to, deletion from, or modification of the contents of Covered Software; or (b) any new file in Source Code Form that contains any Covered Software. 1.11. "Patent Claims" of a Contributor means any patent claim(s), including without limitation, method, process, and apparatus claims, in any patent Licensable by such Contributor that would be infringed, but for the grant of the License, by the making, using, selling, offering for sale, having made, import, or transfer of either its Contributions or its Contributor Version. 1.12. "Secondary License" means either the GNU General Public License, Version 2.0, the GNU Lesser General Public License, Version 2.1, the GNU Affero General Public License, Version 3.0, or any later versions of those licenses. 1.13. "Source Code Form" means the form of the work preferred for making modifications. 1.14. "You" (or "Your") means an individual or a legal entity exercising rights under this License. For legal entities, "You" includes any entity that controls, is controlled by, or is under common control with You. For purposes of this definition, "control" means (a) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (b) ownership of more than fifty percent (50%) of the outstanding shares or beneficial ownership of such entity. 2. License Grants and Conditions -------------------------------- 2.1. 
Grants Each Contributor hereby grants You a world-wide, royalty-free, non-exclusive license: (a) under intellectual property rights (other than patent or trademark) Licensable by such Contributor to use, reproduce, make available, modify, display, perform, distribute, and otherwise exploit its Contributions, either on an unmodified basis, with Modifications, or as part of a Larger Work; and (b) under Patent Claims of such Contributor to make, use, sell, offer for sale, have made, import, and otherwise transfer either its Contributions or its Contributor Version. 2.2. Effective Date The licenses granted in Section 2.1 with respect to any Contribution become effective for each Contribution on the date the Contributor first distributes such Contribution. 2.3. Limitations on Grant Scope The licenses granted in this Section 2 are the only rights granted under this License. No additional rights or licenses will be implied from the distribution or licensing of Covered Software under this License. Notwithstanding Section 2.1(b) above, no patent license is granted by a Contributor: (a) for any code that a Contributor has removed from Covered Software; or (b) for infringements caused by: (i) Your and any other third party's modifications of Covered Software, or (ii) the combination of its Contributions with other software (except as part of its Contributor Version); or (c) under Patent Claims infringed by Covered Software in the absence of its Contributions. This License does not grant any rights in the trademarks, service marks, or logos of any Contributor (except as may be necessary to comply with the notice requirements in Section 3.4). 2.4. Subsequent Licenses No Contributor makes additional grants as a result of Your choice to distribute the Covered Software under a subsequent version of this License (see Section 10.2) or under the terms of a Secondary License (if permitted under the terms of Section 3.3). 2.5. 
Representation Each Contributor represents that the Contributor believes its Contributions are its original creation(s) or it has sufficient rights to grant the rights to its Contributions conveyed by this License. 2.6. Fair Use This License is not intended to limit any rights You have under applicable copyright doctrines of fair use, fair dealing, or other equivalents. 2.7. Conditions Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in Section 2.1. 3. Responsibilities ------------------- 3.1. Distribution of Source Form All distribution of Covered Software in Source Code Form, including any Modifications that You create or to which You contribute, must be under the terms of this License. You must inform recipients that the Source Code Form of the Covered Software is governed by the terms of this License, and how they can obtain a copy of this License. You may not attempt to alter or restrict the recipients' rights in the Source Code Form. 3.2. Distribution of Executable Form If You distribute Covered Software in Executable Form then: (a) such Covered Software must also be made available in Source Code Form, as described in Section 3.1, and You must inform recipients of the Executable Form how they can obtain a copy of such Source Code Form by reasonable means in a timely manner, at a charge no more than the cost of distribution to the recipient; and (b) You may distribute such Executable Form under the terms of this License, or sublicense it under different terms, provided that the license for the Executable Form does not attempt to limit or alter the recipients' rights in the Source Code Form under this License. 3.3. Distribution of a Larger Work You may create and distribute a Larger Work under terms of Your choice, provided that You also comply with the requirements of this License for the Covered Software. 
If the Larger Work is a combination of Covered Software with a work governed by one or more Secondary Licenses, and the Covered Software is not Incompatible With Secondary Licenses, this License permits You to additionally distribute such Covered Software under the terms of such Secondary License(s), so that the recipient of the Larger Work may, at their option, further distribute the Covered Software under the terms of either this License or such Secondary License(s). 3.4. Notices You may not remove or alter the substance of any license notices (including copyright notices, patent notices, disclaimers of warranty, or limitations of liability) contained within the Source Code Form of the Covered Software, except that You may alter any license notices to the extent required to remedy known factual inaccuracies. 3.5. Application of Additional Terms You may choose to offer, and to charge a fee for, warranty, support, indemnity or liability obligations to one or more recipients of Covered Software. However, You may do so only on Your own behalf, and not on behalf of any Contributor. You must make it absolutely clear that any such warranty, support, indemnity, or liability obligation is offered by You alone, and You hereby agree to indemnify every Contributor for any liability incurred by such Contributor as a result of warranty, support, indemnity or liability terms You offer. You may include additional disclaimers of warranty and limitations of liability specific to any jurisdiction. 4. Inability to Comply Due to Statute or Regulation --------------------------------------------------- If it is impossible for You to comply with any of the terms of this License with respect to some or all of the Covered Software due to statute, judicial order, or regulation then You must: (a) comply with the terms of this License to the maximum extent possible; and (b) describe the limitations and the code they affect. 
Such description must be placed in a text file included with all distributions of the Covered Software under this License. Except to the extent prohibited by statute or regulation, such description must be sufficiently detailed for a recipient of ordinary skill to be able to understand it. 5. Termination -------------- 5.1. The rights granted under this License will terminate automatically if You fail to comply with any of its terms. However, if You become compliant, then the rights granted under this License from a particular Contributor are reinstated (a) provisionally, unless and until such Contributor explicitly and finally terminates Your grants, and (b) on an ongoing basis, if such Contributor fails to notify You of the non-compliance by some reasonable means prior to 60 days after You have come back into compliance. Moreover, Your grants from a particular Contributor are reinstated on an ongoing basis if such Contributor notifies You of the non-compliance by some reasonable means, this is the first time You have received notice of non-compliance with this License from such Contributor, and You become compliant prior to 30 days after Your receipt of the notice. 5.2. If You initiate litigation against any entity by asserting a patent infringement claim (excluding declaratory judgment actions, counter-claims, and cross-claims) alleging that a Contributor Version directly or indirectly infringes any patent, then the rights granted to You by any and all Contributors for the Covered Software under Section 2.1 of this License shall terminate. 5.3. In the event of termination under Sections 5.1 or 5.2 above, all end user license agreements (excluding distributors and resellers) which have been validly granted by You or Your distributors under this License prior to termination shall survive termination. ************************************************************************ * * * 6. 
Disclaimer of Warranty * * ------------------------- * * * * Covered Software is provided under this License on an "as is" * * basis, without warranty of any kind, either expressed, implied, or * * statutory, including, without limitation, warranties that the * * Covered Software is free of defects, merchantable, fit for a * * particular purpose or non-infringing. The entire risk as to the * * quality and performance of the Covered Software is with You. * * Should any Covered Software prove defective in any respect, You * * (not any Contributor) assume the cost of any necessary servicing, * * repair, or correction. This disclaimer of warranty constitutes an * * essential part of this License. No use of any Covered Software is * * authorized under this License except under this disclaimer. * * * ************************************************************************ ************************************************************************ * * * 7. Limitation of Liability * * -------------------------- * * * * Under no circumstances and under no legal theory, whether tort * * (including negligence), contract, or otherwise, shall any * * Contributor, or anyone who distributes Covered Software as * * permitted above, be liable to You for any direct, indirect, * * special, incidental, or consequential damages of any character * * including, without limitation, damages for lost profits, loss of * * goodwill, work stoppage, computer failure or malfunction, or any * * and all other commercial damages or losses, even if such party * * shall have been informed of the possibility of such damages. This * * limitation of liability shall not apply to liability for death or * * personal injury resulting from such party's negligence to the * * extent applicable law prohibits such limitation. Some * * jurisdictions do not allow the exclusion or limitation of * * incidental or consequential damages, so this exclusion and * * limitation may not apply to You. 
* * * ************************************************************************ 8. Litigation ------------- Any litigation relating to this License may be brought only in the courts of a jurisdiction where the defendant maintains its principal place of business and such litigation shall be governed by laws of that jurisdiction, without reference to its conflict-of-law provisions. Nothing in this Section shall prevent a party's ability to bring cross-claims or counter-claims. 9. Miscellaneous ---------------- This License represents the complete agreement concerning the subject matter hereof. If any provision of this License is held to be unenforceable, such provision shall be reformed only to the extent necessary to make it enforceable. Any law or regulation which provides that the language of a contract shall be construed against the drafter shall not be used to construe this License against a Contributor. 10. Versions of the License --------------------------- 10.1. New Versions Mozilla Foundation is the license steward. Except as provided in Section 10.3, no one other than the license steward has the right to modify or publish new versions of this License. Each version will be given a distinguishing version number. 10.2. Effect of New Versions You may distribute the Covered Software under the terms of the version of the License under which You originally received the Covered Software, or under the terms of any subsequent version published by the license steward. 10.3. Modified Versions If you create software not governed by this License, and you want to create a new license for such software, you may create and use a modified version of this License if you rename the license and remove any references to the name of the license steward (except to note that such modified license differs from this License). 10.4. 
Distributing Source Code Form that is Incompatible With Secondary Licenses If You choose to distribute Source Code Form that is Incompatible With Secondary Licenses under the terms of this version of the License, the notice described in Exhibit B of this License must be attached. Exhibit A - Source Code Form License Notice ------------------------------------------- This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at https://mozilla.org/MPL/2.0/. If it is not possible or desirable to put the notice in a particular file, then You may include the notice in a location (such as a LICENSE file in a relevant directory) where a recipient would be likely to look for such a notice. You may add additional accurate notices of copyright ownership. Exhibit B - "Incompatible With Secondary Licenses" Notice --------------------------------------------------------- This Source Code Form is "Incompatible With Secondary Licenses", as defined by the Mozilla Public License, v. 2.0. 
jax-ml-ml_dtypes-882eb0f/MANIFEST.in000066400000000000000000000002661510671665600170770ustar00rootroot00000000000000recursive-include ml_dtypes *.h include third_party/eigen/Eigen/Core recursive-include third_party/eigen/Eigen/src/plugins *.h recursive-include third_party/eigen/Eigen/src/Core *.h jax-ml-ml_dtypes-882eb0f/README.md000066400000000000000000000171431510671665600166220ustar00rootroot00000000000000# ml_dtypes [![Unittests](https://github.com/jax-ml/ml_dtypes/actions/workflows/test.yml/badge.svg)](https://github.com/jax-ml/ml_dtypes/actions/workflows/test.yml) [![Wheel Build](https://github.com/jax-ml/ml_dtypes/actions/workflows/wheels.yml/badge.svg)](https://github.com/jax-ml/ml_dtypes/actions/workflows/wheels.yml) [![PyPI version](https://badge.fury.io/py/ml_dtypes.svg)](https://badge.fury.io/py/ml_dtypes) `ml_dtypes` is a stand-alone implementation of several NumPy dtype extensions used in machine learning libraries, including: - [`bfloat16`](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format): an alternative to the standard [`float16`](https://en.wikipedia.org/wiki/Half-precision_floating-point_format) format - 8-bit floating point representations, parameterized by number of exponent and mantissa bits, as well as the bias (if any) and representability of infinity, NaN, and signed zero. * `float8_e3m4` * `float8_e4m3` * `float8_e4m3b11fnuz` * `float8_e4m3fn` * `float8_e4m3fnuz` * `float8_e5m2` * `float8_e5m2fnuz` * `float8_e8m0fnu` - Microscaling (MX) sub-byte floating point representations: * `float4_e2m1fn` * `float6_e2m3fn` * `float6_e3m2fn` - Narrow integer encodings: * `int2` * `int4` * `uint2` * `uint4` See below for specifications of these number formats. 
## Installation The `ml_dtypes` package is tested with Python versions 3.9-3.14, and can be installed with the following command: ``` pip install ml_dtypes ``` To test your installation, you can run the following: ``` pip install absl-py pytest pytest --pyargs ml_dtypes ``` To build from source, clone the repository and run: ``` git submodule init git submodule update pip install . ``` ## Example Usage ```python >>> from ml_dtypes import bfloat16 >>> import numpy as np >>> np.zeros(4, dtype=bfloat16) array([0, 0, 0, 0], dtype=bfloat16) ``` Importing `ml_dtypes` also registers the data types with numpy, so that they may be referred to by their string name: ```python >>> np.dtype('bfloat16') dtype(bfloat16) >>> np.dtype('float8_e5m2') dtype(float8_e5m2) ``` ## Specifications of implemented floating point formats ### `bfloat16` A `bfloat16` number is a single-precision float truncated at 16 bits. Exponent: 8, Mantissa: 7, exponent bias: 127. IEEE 754, with NaN and inf. ### `float4_e2m1fn` Exponent: 2, Mantissa: 1, bias: 1. Extended range: no inf, no NaN. Microscaling format, 4 bits (encoding: `0bSEEM`) using byte storage (higher 4 bits are unused). NaN representation is undefined. Possible absolute values: [`0`, `0.5`, `1`, `1.5`, `2`, `3`, `4`, `6`] ### `float6_e2m3fn` Exponent: 2, Mantissa: 3, bias: 1. Extended range: no inf, no NaN. Microscaling format, 6 bits (encoding: `0bSEEMMM`) using byte storage (higher 2 bits are unused). NaN representation is undefined. Possible values range: [`-7.5`; `7.5`] ### `float6_e3m2fn` Exponent: 3, Mantissa: 2, bias: 3. Extended range: no inf, no NaN. Microscaling format, 6 bits (encoding: `0bSEEEMM`) using byte storage (higher 2 bits are unused). NaN representation is undefined. Possible values range: [`-28`; `28`] ### `float8_e3m4` Exponent: 3, Mantissa: 4, bias: 3. IEEE 754, with NaN and inf. ### `float8_e4m3` Exponent: 4, Mantissa: 3, bias: 7. IEEE 754, with NaN and inf.
### `float8_e4m3b11fnuz` Exponent: 4, Mantissa: 3, bias: 11. Extended range: no inf, NaN represented by 0b1000'0000. ### `float8_e4m3fn` Exponent: 4, Mantissa: 3, bias: 7. Extended range: no inf, NaN represented by 0bS111'1111. The `fn` suffix is for consistency with the corresponding LLVM/MLIR type, signaling this type is not consistent with IEEE-754. The `f` indicates it is finite values only. The `n` indicates it includes NaNs, but only at the outer range. ### `float8_e4m3fnuz` 8-bit floating point with 3 bit mantissa. An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits mantissa. The suffix `fnuz` is consistent with LLVM/MLIR naming and is derived from the differences to IEEE floating point conventions. `F` is for "finite" (no infinities), `N` for special NaN encoding, `UZ` for unsigned zero. This type has the following characteristics: * bit encoding: S1E4M3 - `0bSEEEEMMM` * exponent bias: 8 * infinities: Not supported * NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s - `0b10000000` * denormals when exponent is 0 ### `float8_e5m2` Exponent: 5, Mantissa: 2, bias: 15. IEEE 754, with NaN and inf. ### `float8_e5m2fnuz` 8-bit floating point with 2 bit mantissa. An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits mantissa. The suffix `fnuz` is consistent with LLVM/MLIR naming and is derived from the differences to IEEE floating point conventions. `F` is for "finite" (no infinities), `N` for special NaN encoding, `UZ` for unsigned zero.
This type has the following characteristics: * bit encoding: S1E5M2 - `0bSEEEEEMM` * exponent bias: 16 * infinities: Not supported * NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s - `0b10000000` * denormals when exponent is 0 ### `float8_e8m0fnu` [OpenCompute MX](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) scale format E8M0, which has the following properties: * Unsigned format * 8 exponent bits * Exponent range from -127 to 127 * No zero and infinity * Single NaN value (0xFF). ## `int2`, `int4`, `uint2` and `uint4` 2 and 4-bit integer types, where each element is represented unpacked (i.e., padded up to a byte in memory). NumPy does not support types smaller than a single byte: for example, the distance between adjacent elements in an array (`.strides`) is expressed as an integer number of bytes. Relaxing this restriction would be a considerable engineering project. These types therefore use an unpacked representation, where each element of the array is padded up to a byte in memory. The lower two or four bits of each byte contain the representation of the number, whereas the remaining upper bits are ignored. ## Quirks of low-precision Arithmetic If you're exploring the use of low-precision dtypes in your code, you should be careful to anticipate when the precision loss might lead to surprising results. 
One example is the behavior of aggregations like `sum`; consider this `bfloat16` summation in NumPy (run with version 1.24.2): ```python >>> from ml_dtypes import bfloat16 >>> import numpy as np >>> rng = np.random.default_rng(seed=0) >>> vals = rng.uniform(size=10000).astype(bfloat16) >>> vals.sum() 256 ``` The true sum should be close to 5000, but numpy returns exactly 256: this is because `bfloat16` does not have the precision to increment `256` by values less than `1`: ```python >>> bfloat16(256) + bfloat16(1) 256 ``` After 256, the next representable value in bfloat16 is 258: ```python >>> np.nextafter(bfloat16(256), bfloat16(np.inf)) 258 ``` For better results you can specify that the accumulation should happen in a higher-precision type like `float32`: ```python >>> vals.sum(dtype='float32').astype(bfloat16) 4992 ``` In contrast to NumPy, projects like [JAX](http://jax.readthedocs.io/) which support low-precision arithmetic more natively will often do these kinds of higher-precision accumulations automatically: ```python >>> import jax.numpy as jnp >>> jnp.array(vals).sum() Array(4992, dtype=bfloat16) ``` ## License *This is not an officially supported Google product.* The `ml_dtypes` source code is licensed under the Apache 2.0 license (see [LICENSE](LICENSE)). Pre-compiled wheels are built with the [EIGEN](https://eigen.tuxfamily.org/) project, which is released under the MPL 2.0 license (see [LICENSE.eigen](LICENSE.eigen)). jax-ml-ml_dtypes-882eb0f/RELEASING.md000066400000000000000000000027241510671665600171750ustar00rootroot00000000000000# Releasing ml_dtypes To create a new `ml_dtypes` release, take the following steps: 1. Send a pull request updating the version in `ml_dtypes/__init__.py` to the new version number, as well as updating `CHANGELOG.md` with the changes since the previous release (an example for the 0.2.0 release is [PR #78]). 2. Once this is merged, create the release tag and push it to github. 
An example from the 0.2.0 release: ``` $ git checkout main $ git pull upstream main # upstream is github.com:jax-ml/ml_dtypes.git $ git log # view commit log & ensure the most recent commit # is your version update PR $ git tag -a v0.2.0 -m "v0.2.0 Release" $ git push upstream v0.2.0 ``` 3. Navigate to https://github.com/jax-ml/ml_dtypes/releases/new, and select this new tag. Copy the change description from `CHANGELOG.md` into the release notes, and click *Publish release*. 4. Publishing the release will trigger the CI jobs configured in `.github/workflows/wheels.yml`, which will build the wheels and source distributions and publish them to PyPI. Navigate to https://github.com/jax-ml/ml_dtypes/actions/workflows/wheels.yml and look for the job associated with this release; monitor it to ensure it finishes green (this will take approximately 30 minutes). 5. Once the build is complete, check https://pypi.org/project/ml-dtypes/ to ensure that the new release is present. [PR #78]: https://github.com/jax-ml/ml_dtypes/pull/78 jax-ml-ml_dtypes-882eb0f/ml_dtypes/000077500000000000000000000000001510671665600173355ustar00rootroot00000000000000jax-ml-ml_dtypes-882eb0f/ml_dtypes/__init__.py000066400000000000000000000044651510671665600214570ustar00rootroot00000000000000# Copyright 2022 The ml_dtypes Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
__version__ = "0.5.4" __all__ = [ "__version__", "bfloat16", "finfo", "float4_e2m1fn", "float6_e2m3fn", "float6_e3m2fn", "float8_e3m4", "float8_e4m3", "float8_e4m3b11fnuz", "float8_e4m3fn", "float8_e4m3fnuz", "float8_e5m2", "float8_e5m2fnuz", "float8_e8m0fnu", "iinfo", "int2", "int4", "uint2", "uint4", ] from typing import Type from ml_dtypes._finfo import finfo from ml_dtypes._iinfo import iinfo from ml_dtypes._ml_dtypes_ext import bfloat16 from ml_dtypes._ml_dtypes_ext import float4_e2m1fn from ml_dtypes._ml_dtypes_ext import float6_e2m3fn from ml_dtypes._ml_dtypes_ext import float6_e3m2fn from ml_dtypes._ml_dtypes_ext import float8_e3m4 from ml_dtypes._ml_dtypes_ext import float8_e4m3 from ml_dtypes._ml_dtypes_ext import float8_e4m3b11fnuz from ml_dtypes._ml_dtypes_ext import float8_e4m3fn from ml_dtypes._ml_dtypes_ext import float8_e4m3fnuz from ml_dtypes._ml_dtypes_ext import float8_e5m2 from ml_dtypes._ml_dtypes_ext import float8_e5m2fnuz from ml_dtypes._ml_dtypes_ext import float8_e8m0fnu from ml_dtypes._ml_dtypes_ext import int2 from ml_dtypes._ml_dtypes_ext import int4 from ml_dtypes._ml_dtypes_ext import uint2 from ml_dtypes._ml_dtypes_ext import uint4 import numpy as np bfloat16: Type[np.generic] float4_e2m1fn: Type[np.generic] float6_e2m3fn: Type[np.generic] float6_e3m2fn: Type[np.generic] float8_e3m4: Type[np.generic] float8_e4m3: Type[np.generic] float8_e4m3b11fnuz: Type[np.generic] float8_e4m3fn: Type[np.generic] float8_e4m3fnuz: Type[np.generic] float8_e5m2: Type[np.generic] float8_e5m2fnuz: Type[np.generic] float8_e8m0fnu: Type[np.generic] int2: Type[np.generic] int4: Type[np.generic] uint2: Type[np.generic] uint4: Type[np.generic] del np, Type jax-ml-ml_dtypes-882eb0f/ml_dtypes/_finfo.py000066400000000000000000000545711510671665600211630ustar00rootroot00000000000000# Copyright 2023 The ml_dtypes Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Overload of numpy.finfo to handle dtypes defined in ml_dtypes.""" from ml_dtypes._ml_dtypes_ext import bfloat16 from ml_dtypes._ml_dtypes_ext import float4_e2m1fn from ml_dtypes._ml_dtypes_ext import float6_e2m3fn from ml_dtypes._ml_dtypes_ext import float6_e3m2fn from ml_dtypes._ml_dtypes_ext import float8_e3m4 from ml_dtypes._ml_dtypes_ext import float8_e4m3 from ml_dtypes._ml_dtypes_ext import float8_e4m3b11fnuz from ml_dtypes._ml_dtypes_ext import float8_e4m3fn from ml_dtypes._ml_dtypes_ext import float8_e4m3fnuz from ml_dtypes._ml_dtypes_ext import float8_e5m2 from ml_dtypes._ml_dtypes_ext import float8_e5m2fnuz from ml_dtypes._ml_dtypes_ext import float8_e8m0fnu import numpy as np _bfloat16_dtype = np.dtype(bfloat16) _float4_e2m1fn_dtype = np.dtype(float4_e2m1fn) _float6_e2m3fn_dtype = np.dtype(float6_e2m3fn) _float6_e3m2fn_dtype = np.dtype(float6_e3m2fn) _float8_e3m4_dtype = np.dtype(float8_e3m4) _float8_e4m3_dtype = np.dtype(float8_e4m3) _float8_e4m3b11fnuz_dtype = np.dtype(float8_e4m3b11fnuz) _float8_e4m3fn_dtype = np.dtype(float8_e4m3fn) _float8_e4m3fnuz_dtype = np.dtype(float8_e4m3fnuz) _float8_e5m2_dtype = np.dtype(float8_e5m2) _float8_e5m2fnuz_dtype = np.dtype(float8_e5m2fnuz) _float8_e8m0fnu_dtype = np.dtype(float8_e8m0fnu) class _Bfloat16MachArLike: def __init__(self): smallest_normal = float.fromhex("0x1p-126") self.smallest_normal = bfloat16(smallest_normal) smallest_subnormal = float.fromhex("0x1p-133") self.smallest_subnormal = bfloat16(smallest_subnormal) class _Float4E2m1fnMachArLike: def __init__(self): smallest_normal = 
float.fromhex("0x1p0") self.smallest_normal = float4_e2m1fn(smallest_normal) smallest_subnormal = float.fromhex("0x0.8p0") self.smallest_subnormal = float4_e2m1fn(smallest_subnormal) class _Float6E2m3fnMachArLike: def __init__(self): smallest_normal = float.fromhex("0x1p0") self.smallest_normal = float6_e2m3fn(smallest_normal) smallest_subnormal = float.fromhex("0x0.2p0") self.smallest_subnormal = float6_e2m3fn(smallest_subnormal) class _Float6E3m2fnMachArLike: def __init__(self): smallest_normal = float.fromhex("0x1p-2") self.smallest_normal = float6_e3m2fn(smallest_normal) smallest_subnormal = float.fromhex("0x0.4p-2") self.smallest_subnormal = float6_e3m2fn(smallest_subnormal) class _Float8E3m4MachArLike: def __init__(self): smallest_normal = float.fromhex("0x1p-2") self.smallest_normal = float8_e3m4(smallest_normal) smallest_subnormal = float.fromhex("0x0.1p-2") self.smallest_subnormal = float8_e3m4(smallest_subnormal) class _Float8E4m3MachArLike: def __init__(self): smallest_normal = float.fromhex("0x1p-6") self.smallest_normal = float8_e4m3(smallest_normal) smallest_subnormal = float.fromhex("0x0.2p-6") self.smallest_subnormal = float8_e4m3(smallest_subnormal) class _Float8E4m3b11fnuzMachArLike: def __init__(self): smallest_normal = float.fromhex("0x1p-10") self.smallest_normal = float8_e4m3b11fnuz(smallest_normal) smallest_subnormal = float.fromhex("0x1p-13") self.smallest_subnormal = float8_e4m3b11fnuz(smallest_subnormal) class _Float8E4m3fnMachArLike: def __init__(self): smallest_normal = float.fromhex("0x1p-6") self.smallest_normal = float8_e4m3fn(smallest_normal) smallest_subnormal = float.fromhex("0x1p-9") self.smallest_subnormal = float8_e4m3fn(smallest_subnormal) class _Float8E4m3fnuzMachArLike: def __init__(self): smallest_normal = float.fromhex("0x1p-7") self.smallest_normal = float8_e4m3fnuz(smallest_normal) smallest_subnormal = float.fromhex("0x1p-10") self.smallest_subnormal = float8_e4m3fnuz(smallest_subnormal) class _Float8E5m2MachArLike: def 
__init__(self): smallest_normal = float.fromhex("0x1p-14") self.smallest_normal = float8_e5m2(smallest_normal) smallest_subnormal = float.fromhex("0x1p-16") self.smallest_subnormal = float8_e5m2(smallest_subnormal) class _Float8E5m2fnuzMachArLike: def __init__(self): smallest_normal = float.fromhex("0x1p-15") self.smallest_normal = float8_e5m2fnuz(smallest_normal) smallest_subnormal = float.fromhex("0x1p-17") self.smallest_subnormal = float8_e5m2fnuz(smallest_subnormal) class _Float8E8m0fnuMachArLike: def __init__(self): smallest_normal = float.fromhex("0x1p-127") self.smallest_normal = float8_e8m0fnu(smallest_normal) smallest_subnormal = float.fromhex("0x1p-127") self.smallest_subnormal = float8_e8m0fnu(smallest_subnormal) class finfo(np.finfo): # pylint: disable=invalid-name,missing-class-docstring __doc__ = np.finfo.__doc__ @staticmethod def _bfloat16_finfo(): def float_to_str(f): return "%12.4e" % float(f) tiny = float.fromhex("0x1p-126") resolution = 0.01 eps = float.fromhex("0x1p-7") epsneg = float.fromhex("0x1p-8") max_ = float.fromhex("0x1.FEp127") obj = object.__new__(np.finfo) obj.dtype = _bfloat16_dtype obj.bits = 16 obj.eps = bfloat16(eps) obj.epsneg = bfloat16(epsneg) obj.machep = -7 obj.negep = -8 obj.max = bfloat16(max_) obj.min = bfloat16(-max_) obj.nexp = 8 obj.nmant = 7 obj.iexp = obj.nexp obj.maxexp = 128 obj.minexp = -126 obj.precision = 2 obj.resolution = bfloat16(resolution) # pylint: disable=protected-access obj._machar = _Bfloat16MachArLike() if not hasattr(obj, "tiny"): obj.tiny = bfloat16(tiny) if not hasattr(obj, "smallest_normal"): obj.smallest_normal = obj._machar.smallest_normal obj.smallest_subnormal = obj._machar.smallest_subnormal obj._str_tiny = float_to_str(tiny) obj._str_smallest_normal = float_to_str(tiny) obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal) obj._str_max = float_to_str(max_) obj._str_epsneg = float_to_str(epsneg) obj._str_eps = float_to_str(eps) obj._str_resolution = float_to_str(resolution) # 
pylint: enable=protected-access return obj @staticmethod def _float4_e2m1fn_finfo(): eps = float.fromhex("0x0.8p0") # 0.5 max_ = float.fromhex("0x1.8p2") # 6.0 obj = object.__new__(np.finfo) obj.dtype = _float4_e2m1fn_dtype obj.bits = 4 obj.eps = eps obj.epsneg = eps obj.machep = -1 obj.negep = -1 obj.max = float4_e2m1fn(max_) obj.min = float4_e2m1fn(-max_) obj.nexp = 2 obj.nmant = 1 obj.iexp = obj.nexp obj.maxexp = 3 obj.minexp = 0 obj.precision = 0 obj.resolution = float4_e2m1fn(1.0) # pylint: disable=protected-access obj._machar = _Float4E2m1fnMachArLike() tiny = obj._machar.smallest_normal if not hasattr(obj, "tiny"): obj.tiny = tiny if not hasattr(obj, "smallest_normal"): obj.smallest_normal = tiny obj.smallest_subnormal = obj._machar.smallest_subnormal float_to_str = str obj._str_tiny = float_to_str(tiny) obj._str_smallest_normal = float_to_str(tiny) obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal) obj._str_max = float_to_str(obj.max) obj._str_epsneg = float_to_str(obj.epsneg) obj._str_eps = float_to_str(obj.eps) obj._str_resolution = float_to_str(obj.resolution) # pylint: enable=protected-access return obj @staticmethod def _float6_e2m3fn_finfo(): eps = float.fromhex("0x0.2p0") # 0.125 max_ = float.fromhex("0x1.Ep2") # 7.5 obj = object.__new__(np.finfo) obj.dtype = _float6_e2m3fn_dtype obj.bits = 6 obj.eps = eps obj.epsneg = eps obj.machep = -3 obj.negep = -3 obj.max = float6_e2m3fn(max_) obj.min = float6_e2m3fn(-max_) obj.nexp = 2 obj.nmant = 3 obj.iexp = obj.nexp obj.maxexp = 3 obj.minexp = 0 obj.precision = 0 obj.resolution = float6_e2m3fn(1.0) # pylint: disable=protected-access obj._machar = _Float6E2m3fnMachArLike() tiny = obj._machar.smallest_normal if not hasattr(obj, "tiny"): obj.tiny = tiny if not hasattr(obj, "smallest_normal"): obj.smallest_normal = tiny obj.smallest_subnormal = obj._machar.smallest_subnormal float_to_str = str obj._str_tiny = float_to_str(tiny) obj._str_smallest_normal = float_to_str(tiny) 
obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal) obj._str_max = float_to_str(obj.max) obj._str_epsneg = float_to_str(obj.epsneg) obj._str_eps = float_to_str(obj.eps) obj._str_resolution = float_to_str(obj.resolution) # pylint: enable=protected-access return obj @staticmethod def _float6_e3m2fn_finfo(): eps = float.fromhex("0x1p-2") # 0.25 max_ = float.fromhex("0x1.Cp4") # 28 obj = object.__new__(np.finfo) obj.dtype = _float6_e3m2fn_dtype obj.bits = 6 obj.eps = eps obj.epsneg = eps / 2 obj.machep = -2 obj.negep = -3 obj.max = float6_e3m2fn(max_) obj.min = float6_e3m2fn(-max_) obj.nexp = 3 obj.nmant = 2 obj.iexp = obj.nexp obj.maxexp = 5 obj.minexp = -2 obj.precision = 0 obj.resolution = float6_e3m2fn(1.0) # pylint: disable=protected-access obj._machar = _Float6E3m2fnMachArLike() tiny = obj._machar.smallest_normal if not hasattr(obj, "tiny"): obj.tiny = tiny if not hasattr(obj, "smallest_normal"): obj.smallest_normal = tiny obj.smallest_subnormal = obj._machar.smallest_subnormal float_to_str = str obj._str_tiny = float_to_str(tiny) obj._str_smallest_normal = float_to_str(tiny) obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal) obj._str_max = float_to_str(obj.max) obj._str_epsneg = float_to_str(obj.epsneg) obj._str_eps = float_to_str(obj.eps) obj._str_resolution = float_to_str(obj.resolution) # pylint: enable=protected-access return obj @staticmethod def _float8_e3m4_finfo(): def float_to_str(f): return "%6.2e" % float(f) tiny = float.fromhex("0x1p-2") # 1/4 min normal resolution = 0.1 eps = float.fromhex("0x1p-4") # 1/16 epsneg = float.fromhex("0x1p-5") # 1/32 max_ = float.fromhex("0x1.Fp3") # 15.5 max normal obj = object.__new__(np.finfo) obj.dtype = _float8_e3m4_dtype obj.bits = 8 obj.eps = float8_e3m4(eps) obj.epsneg = float8_e3m4(epsneg) obj.machep = -4 obj.negep = -5 obj.max = float8_e3m4(max_) obj.min = float8_e3m4(-max_) obj.nexp = 3 obj.nmant = 4 obj.iexp = obj.nexp obj.maxexp = 4 obj.minexp = -2 obj.precision = 1 
obj.resolution = float8_e3m4(resolution) # pylint: disable=protected-access obj._machar = _Float8E3m4MachArLike() if not hasattr(obj, "tiny"): obj.tiny = float8_e3m4(tiny) if not hasattr(obj, "smallest_normal"): obj.smallest_normal = obj._machar.smallest_normal obj.smallest_subnormal = obj._machar.smallest_subnormal obj._str_tiny = float_to_str(tiny) obj._str_smallest_normal = float_to_str(tiny) obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal) obj._str_max = float_to_str(max_) obj._str_epsneg = float_to_str(epsneg) obj._str_eps = float_to_str(eps) obj._str_resolution = float_to_str(resolution) # pylint: enable=protected-access return obj @staticmethod def _float8_e4m3_finfo(): def float_to_str(f): return "%6.2e" % float(f) tiny = float.fromhex("0x1p-6") # 1/64 min normal resolution = 0.1 eps = float.fromhex("0x1p-3") # 1/8 epsneg = float.fromhex("0x1p-4") # 1/16 max_ = float.fromhex("0x1.Ep7") # 240 max normal obj = object.__new__(np.finfo) obj.dtype = _float8_e4m3_dtype obj.bits = 8 obj.eps = float8_e4m3(eps) obj.epsneg = float8_e4m3(epsneg) obj.machep = -3 obj.negep = -4 obj.max = float8_e4m3(max_) obj.min = float8_e4m3(-max_) obj.nexp = 4 obj.nmant = 3 obj.iexp = obj.nexp obj.maxexp = 8 obj.minexp = -6 obj.precision = 1 obj.resolution = float8_e4m3(resolution) # pylint: disable=protected-access obj._machar = _Float8E4m3MachArLike() if not hasattr(obj, "tiny"): obj.tiny = float8_e4m3(tiny) if not hasattr(obj, "smallest_normal"): obj.smallest_normal = obj._machar.smallest_normal obj.smallest_subnormal = obj._machar.smallest_subnormal obj._str_tiny = float_to_str(tiny) obj._str_smallest_normal = float_to_str(tiny) obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal) obj._str_max = float_to_str(max_) obj._str_epsneg = float_to_str(epsneg) obj._str_eps = float_to_str(eps) obj._str_resolution = float_to_str(resolution) # pylint: enable=protected-access return obj @staticmethod def _float8_e4m3b11fnuz_finfo(): def float_to_str(f): return 
"%6.2e" % float(f) tiny = float.fromhex("0x1p-10") resolution = 0.1 eps = float.fromhex("0x1p-3") epsneg = float.fromhex("0x1p-4") max_ = float.fromhex("0x1.Ep4") obj = object.__new__(np.finfo) obj.dtype = _float8_e4m3b11fnuz_dtype obj.bits = 8 obj.eps = float8_e4m3b11fnuz(eps) obj.epsneg = float8_e4m3b11fnuz(epsneg) obj.machep = -3 obj.negep = -4 obj.max = float8_e4m3b11fnuz(max_) obj.min = float8_e4m3b11fnuz(-max_) obj.nexp = 4 obj.nmant = 3 obj.iexp = obj.nexp obj.maxexp = 5 obj.minexp = -10 obj.precision = 1 obj.resolution = float8_e4m3b11fnuz(resolution) # pylint: disable=protected-access obj._machar = _Float8E4m3b11fnuzMachArLike() if not hasattr(obj, "tiny"): obj.tiny = float8_e4m3b11fnuz(tiny) if not hasattr(obj, "smallest_normal"): obj.smallest_normal = obj._machar.smallest_normal obj.smallest_subnormal = obj._machar.smallest_subnormal obj._str_tiny = float_to_str(tiny) obj._str_smallest_normal = float_to_str(tiny) obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal) obj._str_max = float_to_str(max_) obj._str_epsneg = float_to_str(epsneg) obj._str_eps = float_to_str(eps) obj._str_resolution = float_to_str(resolution) # pylint: enable=protected-access return obj @staticmethod def _float8_e4m3fn_finfo(): def float_to_str(f): return "%6.2e" % float(f) tiny = float.fromhex("0x1p-6") resolution = 0.1 eps = float.fromhex("0x1p-3") epsneg = float.fromhex("0x1p-4") max_ = float.fromhex("0x1.Cp8") obj = object.__new__(np.finfo) obj.dtype = _float8_e4m3fn_dtype obj.bits = 8 obj.eps = float8_e4m3fn(eps) obj.epsneg = float8_e4m3fn(epsneg) obj.machep = -3 obj.negep = -4 obj.max = float8_e4m3fn(max_) obj.min = float8_e4m3fn(-max_) obj.nexp = 4 obj.nmant = 3 obj.iexp = obj.nexp obj.maxexp = 9 obj.minexp = -6 obj.precision = 1 obj.resolution = float8_e4m3fn(resolution) # pylint: disable=protected-access obj._machar = _Float8E4m3fnMachArLike() if not hasattr(obj, "tiny"): obj.tiny = float8_e4m3fn(tiny) if not hasattr(obj, "smallest_normal"): 
obj.smallest_normal = obj._machar.smallest_normal obj.smallest_subnormal = obj._machar.smallest_subnormal obj._str_tiny = float_to_str(tiny) obj._str_smallest_normal = float_to_str(tiny) obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal) obj._str_max = float_to_str(max_) obj._str_epsneg = float_to_str(epsneg) obj._str_eps = float_to_str(eps) obj._str_resolution = float_to_str(resolution) # pylint: enable=protected-access return obj @staticmethod def _float8_e4m3fnuz_finfo(): def float_to_str(f): return "%6.2e" % float(f) tiny = float.fromhex("0x1p-7") resolution = 0.1 eps = float.fromhex("0x1p-3") epsneg = float.fromhex("0x1p-4") max_ = float.fromhex("0x1.Ep7") obj = object.__new__(np.finfo) obj.dtype = _float8_e4m3fnuz_dtype obj.bits = 8 obj.eps = float8_e4m3fnuz(eps) obj.epsneg = float8_e4m3fnuz(epsneg) obj.machep = -3 obj.negep = -4 obj.max = float8_e4m3fnuz(max_) obj.min = float8_e4m3fnuz(-max_) obj.nexp = 4 obj.nmant = 3 obj.iexp = obj.nexp obj.maxexp = 8 obj.minexp = -7 obj.precision = 1 obj.resolution = float8_e4m3fnuz(resolution) # pylint: disable=protected-access obj._machar = _Float8E4m3fnuzMachArLike() if not hasattr(obj, "tiny"): obj.tiny = float8_e4m3fnuz(tiny) if not hasattr(obj, "smallest_normal"): obj.smallest_normal = obj._machar.smallest_normal obj.smallest_subnormal = obj._machar.smallest_subnormal obj._str_tiny = float_to_str(tiny) obj._str_smallest_normal = float_to_str(tiny) obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal) obj._str_max = float_to_str(max_) obj._str_epsneg = float_to_str(epsneg) obj._str_eps = float_to_str(eps) obj._str_resolution = float_to_str(resolution) # pylint: enable=protected-access return obj @staticmethod def _float8_e5m2_finfo(): def float_to_str(f): return "%6.2e" % float(f) tiny = float.fromhex("0x1p-14") resolution = 0.1 eps = float.fromhex("0x1p-2") epsneg = float.fromhex("0x1p-3") max_ = float.fromhex("0x1.Cp15") obj = object.__new__(np.finfo) obj.dtype = _float8_e5m2_dtype 
obj.bits = 8 obj.eps = float8_e5m2(eps) obj.epsneg = float8_e5m2(epsneg) obj.machep = -2 obj.negep = -3 obj.max = float8_e5m2(max_) obj.min = float8_e5m2(-max_) obj.nexp = 5 obj.nmant = 2 obj.iexp = obj.nexp obj.maxexp = 16 obj.minexp = -14 obj.precision = 1 obj.resolution = float8_e5m2(resolution) # pylint: disable=protected-access obj._machar = _Float8E5m2MachArLike() if not hasattr(obj, "tiny"): obj.tiny = float8_e5m2(tiny) if not hasattr(obj, "smallest_normal"): obj.smallest_normal = obj._machar.smallest_normal obj.smallest_subnormal = obj._machar.smallest_subnormal obj._str_tiny = float_to_str(tiny) obj._str_smallest_normal = float_to_str(tiny) obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal) obj._str_max = float_to_str(max_) obj._str_epsneg = float_to_str(epsneg) obj._str_eps = float_to_str(eps) obj._str_resolution = float_to_str(resolution) # pylint: enable=protected-access return obj @staticmethod def _float8_e5m2fnuz_finfo(): def float_to_str(f): return "%6.2e" % float(f) tiny = float.fromhex("0x1p-15") resolution = 0.1 eps = float.fromhex("0x1p-2") epsneg = float.fromhex("0x1p-3") max_ = float.fromhex("0x1.Cp15") obj = object.__new__(np.finfo) obj.dtype = _float8_e5m2fnuz_dtype obj.bits = 8 obj.eps = float8_e5m2fnuz(eps) obj.epsneg = float8_e5m2fnuz(epsneg) obj.machep = -2 obj.negep = -3 obj.max = float8_e5m2fnuz(max_) obj.min = float8_e5m2fnuz(-max_) obj.nexp = 5 obj.nmant = 2 obj.iexp = obj.nexp obj.maxexp = 16 obj.minexp = -15 obj.precision = 1 obj.resolution = float8_e5m2fnuz(resolution) # pylint: disable=protected-access obj._machar = _Float8E5m2fnuzMachArLike() if not hasattr(obj, "tiny"): obj.tiny = float8_e5m2fnuz(tiny) if not hasattr(obj, "smallest_normal"): obj.smallest_normal = obj._machar.smallest_normal obj.smallest_subnormal = obj._machar.smallest_subnormal obj._str_tiny = float_to_str(tiny) obj._str_smallest_normal = float_to_str(tiny) obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal) obj._str_max = 
float_to_str(max_) obj._str_epsneg = float_to_str(epsneg) obj._str_eps = float_to_str(eps) obj._str_resolution = float_to_str(resolution) # pylint: enable=protected-access return obj @staticmethod def _float8_e8m0fnu_finfo(): def float_to_str(f): return "%6.2e" % float(f) tiny = float.fromhex("0x1p-127") resolution = 0.1 eps = float.fromhex("0x1p+0") epsneg = float.fromhex("0x1p-1") max_ = float.fromhex("0x1p+127") obj = object.__new__(np.finfo) obj.dtype = _float8_e8m0fnu_dtype obj.bits = 8 obj.eps = float8_e8m0fnu(eps) obj.epsneg = float8_e8m0fnu(epsneg) obj.machep = 0 obj.negep = -1 obj.max = float8_e8m0fnu(max_) obj.min = float8_e8m0fnu(tiny) obj.nexp = 8 obj.nmant = 0 obj.iexp = obj.nexp obj.maxexp = 128 obj.minexp = -127 obj.precision = 1 obj.resolution = float8_e8m0fnu(resolution) # pylint: disable=protected-access obj._machar = _Float8E8m0fnuMachArLike() if not hasattr(obj, "tiny"): obj.tiny = float8_e8m0fnu(tiny) if not hasattr(obj, "smallest_normal"): obj.smallest_normal = obj._machar.smallest_normal obj.smallest_subnormal = obj._machar.smallest_subnormal obj._str_tiny = float_to_str(tiny) obj._str_smallest_normal = float_to_str(tiny) obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal) obj._str_max = float_to_str(max_) obj._str_epsneg = float_to_str(epsneg) obj._str_eps = float_to_str(eps) obj._str_resolution = float_to_str(resolution) # pylint: enable=protected-access return obj _finfo_type_map = { _bfloat16_dtype: _bfloat16_finfo, _float4_e2m1fn_dtype: _float4_e2m1fn_finfo, _float6_e2m3fn_dtype: _float6_e2m3fn_finfo, _float6_e3m2fn_dtype: _float6_e3m2fn_finfo, _float8_e3m4_dtype: _float8_e3m4_finfo, _float8_e4m3_dtype: _float8_e4m3_finfo, _float8_e4m3fn_dtype: _float8_e4m3fn_finfo, _float8_e4m3fnuz_dtype: _float8_e4m3fnuz_finfo, _float8_e4m3b11fnuz_dtype: _float8_e4m3b11fnuz_finfo, _float8_e5m2_dtype: _float8_e5m2_finfo, _float8_e5m2fnuz_dtype: _float8_e5m2fnuz_finfo, _float8_e8m0fnu_dtype: _float8_e8m0fnu_finfo, } _finfo_name_map = 
{t.name: t for t in _finfo_type_map} _finfo_cache = { t: init_fn.__func__() # pytype: disable=attribute-error for t, init_fn in _finfo_type_map.items() } def __new__(cls, dtype): if isinstance(dtype, str): key = cls._finfo_name_map.get(dtype) elif isinstance(dtype, np.dtype): key = dtype else: key = np.dtype(dtype) i = cls._finfo_cache.get(key) if i is not None: return i return super().__new__(cls, dtype) jax-ml-ml_dtypes-882eb0f/ml_dtypes/_iinfo.py000066400000000000000000000037531510671665600211620ustar00rootroot00000000000000# Copyright 2023 The ml_dtypes Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Overload of numpy.iinfo to handle dtypes defined in ml_dtypes.""" from ml_dtypes._ml_dtypes_ext import int2 from ml_dtypes._ml_dtypes_ext import int4 from ml_dtypes._ml_dtypes_ext import uint2 from ml_dtypes._ml_dtypes_ext import uint4 import numpy as np _int2_dtype = np.dtype(int2) _uint2_dtype = np.dtype(uint2) _int4_dtype = np.dtype(int4) _uint4_dtype = np.dtype(uint4) class iinfo: # pylint: disable=invalid-name,missing-class-docstring kind: str bits: int min: int max: int dtype: np.dtype def __init__(self, int_type): if int_type == _int2_dtype: self.dtype = _int2_dtype self.kind = "i" self.bits = 2 self.min = -2 self.max = 1 elif int_type == _uint2_dtype: self.dtype = _uint2_dtype self.kind = "u" self.bits = 2 self.min = 0 self.max = 3 elif int_type == _int4_dtype: self.dtype = _int4_dtype self.kind = "i" self.bits = 4 self.min = -8 self.max = 7 elif int_type == _uint4_dtype: self.dtype = _uint4_dtype self.kind = "u" self.bits = 4 self.min = 0 self.max = 15 else: ii = np.iinfo(int_type) self.dtype = ii.dtype self.kind = ii.kind self.bits = ii.bits self.min = ii.min self.max = ii.max def __repr__(self): return f"iinfo(min={self.min}, max={self.max}, dtype={self.dtype})" def __str__(self): return repr(self) jax-ml-ml_dtypes-882eb0f/ml_dtypes/_src/000077500000000000000000000000001510671665600202635ustar00rootroot00000000000000jax-ml-ml_dtypes-882eb0f/ml_dtypes/_src/common.h000066400000000000000000000104341510671665600217260ustar00rootroot00000000000000/* Copyright 2022 The ml_dtypes Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #ifndef ML_DTYPES_COMMON_H_ #define ML_DTYPES_COMMON_H_ // Must be included first // clang-format off #include "ml_dtypes/_src/numpy.h" // clang-format on #include #include //NOLINT #include "Eigen/Core" namespace ml_dtypes { struct PyDecrefDeleter { void operator()(PyObject* p) const { Py_DECREF(p); } }; // Safe container for an owned PyObject. On destruction, the reference count of // the contained object will be decremented. using Safe_PyObjectPtr = std::unique_ptr; inline Safe_PyObjectPtr make_safe(PyObject* object) { return Safe_PyObjectPtr(object); } template struct TypeDescriptor { // typedef ... T; // Representation type in memory for NumPy values of type // static int Dtype() { return NPY_...; } // Numpy type number for T. }; template <> struct TypeDescriptor { typedef unsigned char T; static int Dtype() { return NPY_UBYTE; } }; template <> struct TypeDescriptor { // NOLINT typedef unsigned short T; // NOLINT static int Dtype() { return NPY_USHORT; } }; // We register "int", "long", and "long long" types for portability across // Linux, where "int" and "long" are the same type, and Windows, where "long" // and "longlong" are the same type.
// Each TypeDescriptor specialization that follows pairs an in-memory
// representation type `T` with the NumPy type number returned by `Dtype()`:
// unsigned int/long/long long, signed char, short, int, long, long long, the
// NPY_BOOL entry (stored as `unsigned char`), Eigen::half, float, double,
// long double and the three std::complex variants. `is_complex` /
// `is_complex_v` is a trait that is true only for std::complex types.
// NOTE(review): the template argument lists (`<...>`) of these declarations
// and the targets of the bare `#include` lines above appear to be missing in
// this copy — confirm against the upstream header before editing the code.
template <> struct TypeDescriptor { typedef unsigned int T; static int Dtype() { return NPY_UINT; } }; template <> struct TypeDescriptor { // NOLINT typedef unsigned long T; // NOLINT static int Dtype() { return NPY_ULONG; } }; template <> struct TypeDescriptor { // NOLINT typedef unsigned long long T; // NOLINT static int Dtype() { return NPY_ULONGLONG; } }; template <> struct TypeDescriptor { typedef signed char T; static int Dtype() { return NPY_BYTE; } }; template <> struct TypeDescriptor { // NOLINT typedef short T; // NOLINT static int Dtype() { return NPY_SHORT; } }; template <> struct TypeDescriptor { typedef int T; static int Dtype() { return NPY_INT; } }; template <> struct TypeDescriptor { // NOLINT typedef long T; // NOLINT static int Dtype() { return NPY_LONG; } }; template <> struct TypeDescriptor { // NOLINT typedef long long T; // NOLINT static int Dtype() { return NPY_LONGLONG; } }; template <> struct TypeDescriptor { typedef unsigned char T; static int Dtype() { return NPY_BOOL; } }; template <> struct TypeDescriptor { typedef Eigen::half T; static int Dtype() { return NPY_HALF; } }; template <> struct TypeDescriptor { typedef float T; static int Dtype() { return NPY_FLOAT; } }; template <> struct TypeDescriptor { typedef double T; static int Dtype() { return NPY_DOUBLE; } }; template <> struct TypeDescriptor { typedef long double T; static int Dtype() { return NPY_LONGDOUBLE; } }; template <> struct TypeDescriptor> { typedef std::complex T; static int Dtype() { return NPY_CFLOAT; } }; template <> struct TypeDescriptor> { typedef std::complex T; static int Dtype() { return NPY_CDOUBLE; } }; template <> struct TypeDescriptor> { typedef std::complex T; static int Dtype() { return NPY_CLONGDOUBLE; } }; template struct is_complex : std::false_type {}; template struct is_complex> : std::true_type {}; template inline constexpr bool is_complex_v = is_complex::value; } // namespace ml_dtypes #endif // ML_DTYPES_COMMON_H_ 
jax-ml-ml_dtypes-882eb0f/ml_dtypes/_src/custom_float.h000066400000000000000000001010051510671665600231300ustar00000000000000000000000000/* Copyright 2022 The ml_dtypes Authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #ifndef ML_DTYPES_CUSTOM_FLOAT_H_ #define ML_DTYPES_CUSTOM_FLOAT_H_ // Must be included first // clang-format off #include "ml_dtypes/_src/numpy.h" // NOLINT // clang-format on // Support utilities for adding custom floating-point dtypes to TensorFlow, // such as bfloat16, and float8_*. #include // NOLINT #include // NOLINT #include // NOLINT #include // NOLINT #include // NOLINT #include // NOLINT #include // NOLINT // Place `` before to avoid a build failure in macOS. #include #include "Eigen/Core" #include "ml_dtypes/_src/common.h" // NOLINT #include "ml_dtypes/_src/ufuncs.h" // NOLINT #undef copysign // TODO(ddunleavy): temporary fix for Windows bazel build // Possible this has to do with numpy.h being included before // system headers and in bfloat16.{cc,h}? #if NPY_ABI_VERSION < 0x02000000 #define PyArray_DescrProto PyArray_Descr #endif namespace ml_dtypes { template struct CustomFloatType { static int Dtype() { return npy_type; } // Registered numpy type ID. Global variable populated by the registration // code. Protected by the GIL. static int npy_type; // Pointer to the python type object we are using.
This is either a pointer // to type, if we choose to register it, or to the python type // registered by another system into NumPy. static PyObject* type_ptr; static PyType_Spec type_spec; static PyType_Slot type_slots[]; static PyArray_ArrFuncs arr_funcs; static PyArray_DescrProto npy_descr_proto; static PyArray_Descr* npy_descr; }; template int CustomFloatType::npy_type = NPY_NOTYPE; template PyObject* CustomFloatType::type_ptr = nullptr; template PyArray_DescrProto CustomFloatType::npy_descr_proto; template PyArray_Descr* CustomFloatType::npy_descr = nullptr; // Representation of a Python custom float object. template struct PyCustomFloat { PyObject_HEAD; // Python object header T value; }; // Returns true if 'object' is an instance of the Python type registered for T // (TypeDescriptor<T>::type_ptr). NOTE(review): PyObject_IsInstance returns -1 // on error, which converts to `true` here — confirm that path cannot occur. template bool PyCustomFloat_Check(PyObject* object) { return PyObject_IsInstance(object, TypeDescriptor::type_ptr); } // Extracts the value of a PyCustomFloat object. No type check is done here: // callers must already know that `object` is a PyCustomFloat. template T PyCustomFloat_CustomFloat(PyObject* object) { return reinterpret_cast*>(object)->value; } // Constructs a PyCustomFloat object holding `x`. If tp_alloc fails, the // returned pointer is null and no value is stored. template Safe_PyObjectPtr PyCustomFloat_FromT(T x) { PyTypeObject* type = reinterpret_cast(TypeDescriptor::type_ptr); Safe_PyObjectPtr ref = make_safe(type->tp_alloc(type, 0)); PyCustomFloat* p = reinterpret_cast*>(ref.get()); if (p) { p->value = x; } return ref; } // Converts a Python object to a custom float value of type T. Returns true on // success. Returns false on failure; a Python error is set only when a // conversion step itself failed, not when `arg` is merely of an unsupported // type (the final fallthrough returns false without setting an error).
template bool CastToCustomFloat(PyObject* arg, T* output) { if (PyCustomFloat_Check(arg)) { *output = PyCustomFloat_CustomFloat(arg); return true; } if (PyFloat_Check(arg)) { double d = PyFloat_AsDouble(arg); if (PyErr_Occurred()) { return false; } // TODO(phawkins): check for overflow *output = T(d); return true; } if (PyLong_Check(arg)) { long l = PyLong_AsLong(arg); // NOLINT if (PyErr_Occurred()) { return false; } // TODO(phawkins): check for overflow *output = T(static_cast(l)); return true; } if (PyArray_IsScalar(arg, Half)) { Eigen::half f; PyArray_ScalarAsCtype(arg, &f); *output = T(f); return true; } if (PyArray_IsScalar(arg, Float)) { float f; PyArray_ScalarAsCtype(arg, &f); *output = T(f); return true; } if (PyArray_IsScalar(arg, Double)) { double f; PyArray_ScalarAsCtype(arg, &f); *output = T(f); return true; } if (PyArray_IsScalar(arg, LongDouble)) { long double f; PyArray_ScalarAsCtype(arg, &f); *output = T(f); return true; } if (PyArray_IsScalar(arg, Integer)) { int64_t i; PyArray_CastScalarToCtype(arg, &i, PyArray_DescrFromType(NPY_INT64)); *output = T(i); return true; } if (PyArray_IsZeroDim(arg)) { Safe_PyObjectPtr ref; PyArrayObject* arr = reinterpret_cast(arg); if (PyArray_TYPE(arr) != TypeDescriptor::Dtype()) { ref = make_safe(PyArray_Cast(arr, TypeDescriptor::Dtype())); if (PyErr_Occurred()) { return false; } arg = ref.get(); arr = reinterpret_cast(arg); } *output = *reinterpret_cast(PyArray_DATA(arr)); return true; } return false; } template bool SafeCastToCustomFloat(PyObject* arg, T* output) { if (PyCustomFloat_Check(arg)) { *output = PyCustomFloat_CustomFloat(arg); return true; } return false; } // Converts a PyReduceFloat into a PyFloat. template PyObject* PyCustomFloat_Float(PyObject* self) { T x = PyCustomFloat_CustomFloat(self); return PyFloat_FromDouble(static_cast(static_cast(x))); } // Converts a PyReduceFloat into a PyInt. 
template PyObject* PyCustomFloat_Int(PyObject* self) { T x = PyCustomFloat_CustomFloat(self); long y = static_cast(static_cast(x)); // NOLINT return PyLong_FromLong(y); } // Negates a PyCustomFloat. template PyObject* PyCustomFloat_Negative(PyObject* self) { T x = PyCustomFloat_CustomFloat(self); return PyCustomFloat_FromT(-x).release(); } template PyObject* PyCustomFloat_Add(PyObject* a, PyObject* b) { T x, y; if (SafeCastToCustomFloat(a, &x) && SafeCastToCustomFloat(b, &y)) { return PyCustomFloat_FromT(x + y).release(); } return PyArray_Type.tp_as_number->nb_add(a, b); } template PyObject* PyCustomFloat_Subtract(PyObject* a, PyObject* b) { T x, y; if (SafeCastToCustomFloat(a, &x) && SafeCastToCustomFloat(b, &y)) { return PyCustomFloat_FromT(x - y).release(); } return PyArray_Type.tp_as_number->nb_subtract(a, b); } template PyObject* PyCustomFloat_Multiply(PyObject* a, PyObject* b) { T x, y; if (SafeCastToCustomFloat(a, &x) && SafeCastToCustomFloat(b, &y)) { return PyCustomFloat_FromT(x * y).release(); } return PyArray_Type.tp_as_number->nb_multiply(a, b); } template PyObject* PyCustomFloat_TrueDivide(PyObject* a, PyObject* b) { T x, y; if (SafeCastToCustomFloat(a, &x) && SafeCastToCustomFloat(b, &y)) { return PyCustomFloat_FromT(x / y).release(); } return PyArray_Type.tp_as_number->nb_true_divide(a, b); } // Constructs a new PyCustomFloat. 
template PyObject* PyCustomFloat_New(PyTypeObject* type, PyObject* args, PyObject* kwds) { if (kwds && PyDict_Size(kwds)) { PyErr_SetString(PyExc_TypeError, "constructor takes no keyword arguments"); return nullptr; } Py_ssize_t size = PyTuple_Size(args); if (size != 1) { PyErr_Format(PyExc_TypeError, "expected number as argument to %s constructor", TypeDescriptor::kTypeName); return nullptr; } PyObject* arg = PyTuple_GetItem(args, 0); T value; if (PyCustomFloat_Check(arg)) { Py_INCREF(arg); return arg; } else if (CastToCustomFloat(arg, &value)) { return PyCustomFloat_FromT(value).release(); } else if (PyArray_Check(arg)) { PyArrayObject* arr = reinterpret_cast(arg); if (PyArray_TYPE(arr) != TypeDescriptor::Dtype()) { return PyArray_Cast(arr, TypeDescriptor::Dtype()); } else { Py_INCREF(arg); return arg; } } else if (PyUnicode_Check(arg) || PyBytes_Check(arg)) { // Parse float from string, then cast to T. PyObject* f = PyFloat_FromString(arg); if (CastToCustomFloat(f, &value)) { return PyCustomFloat_FromT(value).release(); } } PyErr_Format(PyExc_TypeError, "expected number, got %s", Py_TYPE(arg)->tp_name); return nullptr; } // Comparisons on PyCustomFloats. template PyObject* PyCustomFloat_RichCompare(PyObject* a, PyObject* b, int op) { T x, y; if (!SafeCastToCustomFloat(a, &x) || !SafeCastToCustomFloat(b, &y)) { return PyGenericArrType_Type.tp_richcompare(a, b, op); } bool result; switch (op) { case Py_LT: result = x < y; break; case Py_LE: result = x <= y; break; case Py_EQ: result = x == y; break; case Py_NE: result = x != y; break; case Py_GT: result = x > y; break; case Py_GE: result = x >= y; break; default: PyErr_SetString(PyExc_ValueError, "Invalid op type"); return nullptr; } PyArrayScalar_RETURN_BOOL_FROM_LONG(result); } // Implementation of repr() for PyCustomFloat. template PyObject* PyCustomFloat_Repr(PyObject* self) { T x = reinterpret_cast*>(self)->value; float f = static_cast(x); std::ostringstream s; s << (std::isnan(f) ? 
std::abs(f) : f); return PyUnicode_FromString(s.str().c_str()); } // Implementation of str() for PyCustomFloat. template PyObject* PyCustomFloat_Str(PyObject* self) { T x = reinterpret_cast*>(self)->value; float f = static_cast(x); std::ostringstream s; s << (std::isnan(f) ? std::abs(f) : f); return PyUnicode_FromString(s.str().c_str()); } // _Py_HashDouble changed its prototype for Python 3.10 so we use an overload to // handle the two possibilities. // NOLINTNEXTLINE(clang-diagnostic-unused-function) inline Py_hash_t HashImpl(Py_hash_t (*hash_double)(PyObject*, double), PyObject* self, double value) { return hash_double(self, value); } // NOLINTNEXTLINE(clang-diagnostic-unused-function) inline Py_hash_t HashImpl(Py_hash_t (*hash_double)(double), PyObject* self, double value) { return hash_double(value); } // Hash function for PyCustomFloat. template Py_hash_t PyCustomFloat_Hash(PyObject* self) { T x = reinterpret_cast*>(self)->value; return HashImpl(&_Py_HashDouble, self, static_cast(x)); } template PyType_Slot CustomFloatType::type_slots[] = { {Py_tp_new, reinterpret_cast(PyCustomFloat_New)}, {Py_tp_repr, reinterpret_cast(PyCustomFloat_Repr)}, {Py_tp_hash, reinterpret_cast(PyCustomFloat_Hash)}, {Py_tp_str, reinterpret_cast(PyCustomFloat_Str)}, {Py_tp_doc, reinterpret_cast(const_cast(TypeDescriptor::kTpDoc))}, {Py_tp_richcompare, reinterpret_cast(PyCustomFloat_RichCompare)}, {Py_nb_add, reinterpret_cast(PyCustomFloat_Add)}, {Py_nb_subtract, reinterpret_cast(PyCustomFloat_Subtract)}, {Py_nb_multiply, reinterpret_cast(PyCustomFloat_Multiply)}, {Py_nb_negative, reinterpret_cast(PyCustomFloat_Negative)}, {Py_nb_int, reinterpret_cast(PyCustomFloat_Int)}, {Py_nb_float, reinterpret_cast(PyCustomFloat_Float)}, {Py_nb_true_divide, reinterpret_cast(PyCustomFloat_TrueDivide)}, {0, nullptr}, }; template PyType_Spec CustomFloatType::type_spec = { /*.name=*/TypeDescriptor::kQualifiedTypeName, /*.basicsize=*/static_cast(sizeof(PyCustomFloat)), /*.itemsize=*/0, 
/*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*.slots=*/CustomFloatType::type_slots, }; // Numpy support template PyArray_ArrFuncs CustomFloatType::arr_funcs; template PyArray_DescrProto GetCustomFloatDescrProto() { return { PyObject_HEAD_INIT(nullptr) /*typeobj=*/nullptr, // Filled in later /*kind=*/TypeDescriptor::kNpyDescrKind, /*type=*/TypeDescriptor::kNpyDescrType, /*byteorder=*/TypeDescriptor::kNpyDescrByteorder, /*flags=*/NPY_USE_SETITEM, /*type_num=*/0, /*elsize=*/sizeof(T), /*alignment=*/alignof(T), /*subarray=*/nullptr, /*fields=*/nullptr, /*names=*/nullptr, /*f=*/&CustomFloatType::arr_funcs, /*metadata=*/nullptr, /*c_metadata=*/nullptr, /*hash=*/-1, // -1 means "not computed yet". }; } // Implementations of NumPy array methods. template PyObject* NPyCustomFloat_GetItem(void* data, void* arr) { T x; memcpy(&x, data, sizeof(T)); return PyFloat_FromDouble(static_cast(x)); } template int NPyCustomFloat_SetItem(PyObject* item, void* data, void* arr) { T x; if (!CastToCustomFloat(item, &x)) { PyErr_Format(PyExc_TypeError, "expected number, got %s", Py_TYPE(item)->tp_name); return -1; } memcpy(data, &x, sizeof(T)); return 0; } inline void ByteSwap16(void* value) { char* p = reinterpret_cast(value); std::swap(p[0], p[1]); } template int NPyCustomFloat_Compare(const void* a, const void* b, void* arr) { T x; memcpy(&x, a, sizeof(T)); T y; memcpy(&y, b, sizeof(T)); float fy(y); float fx(x); if (fx < fy) { return -1; } if (fy < fx) { return 1; } // NaNs sort to the end. 
if (!Eigen::numext::isnan(fx) && Eigen::numext::isnan(fy)) { return -1; } if (Eigen::numext::isnan(fx) && !Eigen::numext::isnan(fy)) { return 1; } return 0; } template void NPyCustomFloat_CopySwapN(void* dstv, npy_intp dstride, void* srcv, npy_intp sstride, npy_intp n, int swap, void* arr) { static_assert(sizeof(T) == sizeof(int16_t) || sizeof(T) == sizeof(int8_t), "Not supported"); char* dst = reinterpret_cast(dstv); char* src = reinterpret_cast(srcv); if (src) { if (swap && sizeof(T) == sizeof(int16_t)) { for (npy_intp i = 0; i < n; i++) { char* r = dst + dstride * i; memcpy(r, src + sstride * i, sizeof(T)); ByteSwap16(r); } } else if (dstride == sizeof(T) && sstride == sizeof(T)) { memcpy(dst, src, n * sizeof(T)); } else { for (npy_intp i = 0; i < n; i++) { memcpy(dst + dstride * i, src + sstride * i, sizeof(T)); } } } else { // In-place swap when src is NULL if (swap && sizeof(T) == sizeof(int16_t)) { for (npy_intp i = 0; i < n; i++) { char* r = dst + dstride * i; ByteSwap16(r); } } } } template void NPyCustomFloat_CopySwap(void* dst, void* src, int swap, void* arr) { static_assert(sizeof(T) == sizeof(int16_t) || sizeof(T) == sizeof(int8_t), "Not supported"); if (src) { memcpy(dst, src, sizeof(T)); } if (swap && sizeof(T) == sizeof(int16_t)) { ByteSwap16(dst); } } template npy_bool NPyCustomFloat_NonZero(void* data, void* arr) { T x; memcpy(&x, data, sizeof(x)); return x != static_cast(0); } template int NPyCustomFloat_Fill(void* buffer_raw, npy_intp length, void* ignored) { T* const buffer = reinterpret_cast(buffer_raw); const float start(buffer[0]); const float delta = static_cast(buffer[1]) - start; for (npy_intp i = 2; i < length; ++i) { buffer[i] = static_cast(start + i * delta); } return 0; } template void NPyCustomFloat_DotFunc(void* ip1, npy_intp is1, void* ip2, npy_intp is2, void* op, npy_intp n, void* arr) { char* c1 = reinterpret_cast(ip1); char* c2 = reinterpret_cast(ip2); float acc = 0.0f; for (npy_intp i = 0; i < n; ++i) { T* const b1 = 
reinterpret_cast(c1); T* const b2 = reinterpret_cast(c2); acc += static_cast(*b1) * static_cast(*b2); c1 += is1; c2 += is2; } T* out = reinterpret_cast(op); *out = static_cast(acc); } template int NPyCustomFloat_CompareFunc(const void* v1, const void* v2, void* arr) { T b1 = *reinterpret_cast(v1); T b2 = *reinterpret_cast(v2); if (b1 < b2) { return -1; } if (b1 > b2) { return 1; } return 0; } template int NPyCustomFloat_ArgMaxFunc(void* data, npy_intp n, npy_intp* max_ind, void* arr) { const T* bdata = reinterpret_cast(data); // Start with a max_val of NaN, this results in the first iteration preferring // bdata[0]. float max_val = std::numeric_limits::quiet_NaN(); for (npy_intp i = 0; i < n; ++i) { // This condition is chosen so that NaNs are always considered "max". if (!(static_cast(bdata[i]) <= max_val)) { max_val = static_cast(bdata[i]); *max_ind = i; // NumPy stops at the first NaN. if (Eigen::numext::isnan(max_val)) { break; } } } return 0; } template int NPyCustomFloat_ArgMinFunc(void* data, npy_intp n, npy_intp* min_ind, void* arr) { const T* bdata = reinterpret_cast(data); float min_val = std::numeric_limits::quiet_NaN(); // Start with a min_val of NaN, this results in the first iteration preferring // bdata[0]. for (npy_intp i = 0; i < n; ++i) { // This condition is chosen so that NaNs are always considered "min". if (!(static_cast(bdata[i]) >= min_val)) { min_val = static_cast(bdata[i]); *min_ind = i; // NumPy stops at the first NaN. if (Eigen::numext::isnan(min_val)) { break; } } } return 0; } template float CastToFloat(T value) { if constexpr (is_complex_v) { return CastToFloat(value.real()); } else { return static_cast(value); } } // Performs a NumPy array cast from type 'From' to 'To'. 
template void NPyCast(void* from_void, void* to_void, npy_intp n, void* fromarr,
             void* toarr) {
  // Converts n elements. Each value passes through CastToFloat (which keeps
  // only the real part of complex inputs) before conversion to the
  // destination element type.
  const auto* from = reinterpret_cast::T*>(from_void);
  auto* to = reinterpret_cast::T*>(to_void);
  for (npy_intp i = 0; i < n; ++i) {
    to[i] = static_cast::T>(
        static_cast(CastToFloat(from[i])));
  }
}

// Registers a cast between T (a reduced float) and type 'OtherT'. 'numpy_type'
// is the NumPy type corresponding to 'OtherT'.
// Conversion functions are registered in both directions. Returns false if
// either registration fails.
// NOTE(review): the new reference returned by PyArray_DescrFromType is never
// released here.
template bool RegisterCustomFloatCast(int numpy_type = TypeDescriptor::Dtype()) {
  PyArray_Descr* descr = PyArray_DescrFromType(numpy_type);
  if (PyArray_RegisterCastFunc(descr, TypeDescriptor::Dtype(), NPyCast) < 0) {
    return false;
  }
  if (PyArray_RegisterCastFunc(CustomFloatType::npy_descr, numpy_type, NPyCast) < 0) {
    return false;
  }
  return true;
}

// Registers conversion functions between T and NumPy's builtin float, bool,
// integer and complex types. Then declares which casts NumPy may treat as
// safe: T to float/double/longdouble and the complex types, and
// bool/ubyte/byte to T. Returns false on the first failure.
template bool RegisterFloatCasts() {
  if (!RegisterCustomFloatCast(NPY_HALF)) {
    return false;
  }

  if (!RegisterCustomFloatCast(NPY_FLOAT)) {
    return false;
  }
  if (!RegisterCustomFloatCast(NPY_DOUBLE)) {
    return false;
  }
  if (!RegisterCustomFloatCast(NPY_LONGDOUBLE)) {
    return false;
  }
  if (!RegisterCustomFloatCast(NPY_BOOL)) {
    return false;
  }
  if (!RegisterCustomFloatCast(NPY_UBYTE)) {
    return false;
  }
  if (!RegisterCustomFloatCast(NPY_USHORT)) {  // NOLINT
    return false;
  }
  if (!RegisterCustomFloatCast(NPY_UINT)) {
    return false;
  }
  if (!RegisterCustomFloatCast(NPY_ULONG)) {  // NOLINT
    return false;
  }
  if (!RegisterCustomFloatCast(  // NOLINT
          NPY_ULONGLONG)) {
    return false;
  }
  if (!RegisterCustomFloatCast(NPY_BYTE)) {
    return false;
  }
  if (!RegisterCustomFloatCast(NPY_SHORT)) {  // NOLINT
    return false;
  }
  if (!RegisterCustomFloatCast(NPY_INT)) {
    return false;
  }
  if (!RegisterCustomFloatCast(NPY_LONG)) {  // NOLINT
    return false;
  }
  if (!RegisterCustomFloatCast(NPY_LONGLONG)) {  // NOLINT
    return false;
  }
  // Following the numpy convention. imag part is dropped when converting to
  // float.
  if (!RegisterCustomFloatCast>(NPY_CFLOAT)) {
    return false;
  }
  if (!RegisterCustomFloatCast>(NPY_CDOUBLE)) {
    return false;
  }
  if (!RegisterCustomFloatCast>(NPY_CLONGDOUBLE)) {
    return false;
  }

  // Safe casts from T to other types
  if (PyArray_RegisterCanCast(TypeDescriptor::npy_descr, NPY_FLOAT,
                              NPY_NOSCALAR) < 0) {
    return false;
  }
  if (PyArray_RegisterCanCast(TypeDescriptor::npy_descr, NPY_DOUBLE,
                              NPY_NOSCALAR) < 0) {
    return false;
  }
  if (PyArray_RegisterCanCast(TypeDescriptor::npy_descr, NPY_LONGDOUBLE,
                              NPY_NOSCALAR) < 0) {
    return false;
  }
  if (PyArray_RegisterCanCast(TypeDescriptor::npy_descr, NPY_CFLOAT,
                              NPY_NOSCALAR) < 0) {
    return false;
  }
  if (PyArray_RegisterCanCast(TypeDescriptor::npy_descr, NPY_CDOUBLE,
                              NPY_NOSCALAR) < 0) {
    return false;
  }
  if (PyArray_RegisterCanCast(TypeDescriptor::npy_descr, NPY_CLONGDOUBLE,
                              NPY_NOSCALAR) < 0) {
    return false;
  }

  // Safe casts to T from other types
  // NOTE(review): the descriptors returned by PyArray_DescrFromType below are
  // new references that are never released.
  if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_BOOL),
                              TypeDescriptor::Dtype(), NPY_NOSCALAR) < 0) {
    return false;
  }
  if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_UBYTE),
                              TypeDescriptor::Dtype(), NPY_NOSCALAR) < 0) {
    return false;
  }
  if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_BYTE),
                              TypeDescriptor::Dtype(), NPY_NOSCALAR) < 0) {
    return false;
  }

  return true;
}

// Registers T loops for the named NumPy ufuncs. The && chain stops at, and
// reports, the first registration that fails.
template bool RegisterFloatUFuncs(PyObject* numpy) {
  bool ok =
      RegisterUFunc, T, T, T>, T>(numpy, "add") &&
      RegisterUFunc, T, T, T>, T>(numpy, "subtract") &&
      RegisterUFunc, T, T, T>, T>(numpy, "multiply") &&
      RegisterUFunc, T, T, T>, T>(numpy, "divide") &&
      RegisterUFunc, T, T, T>, T>(numpy, "logaddexp") &&
      RegisterUFunc, T, T, T>, T>(numpy, "logaddexp2") &&
      RegisterUFunc, T, T>, T>(numpy, "negative") &&
      RegisterUFunc, T, T>, T>(numpy, "positive") &&
      RegisterUFunc, T, T, T>, T>(numpy, "true_divide") &&
      RegisterUFunc, T, T, T>, T>(
          numpy, "floor_divide") &&
      RegisterUFunc, T, T, T>, T>(numpy, "power") &&
      RegisterUFunc, T, T, T>, T>(numpy, "remainder") &&
      RegisterUFunc, T, T, T>, T>(numpy, "mod") &&
      RegisterUFunc, T, T, T>, T>(numpy, "fmod") &&
      RegisterUFunc, T, T, T, T>, T>(numpy, "divmod") &&
      RegisterUFunc, T, T>, T>(numpy, "absolute") &&
      RegisterUFunc, T, T>, T>(numpy, "fabs") &&
      RegisterUFunc, T, T>, T>(numpy, "rint") &&
      RegisterUFunc, T, T>, T>(numpy, "sign") &&
      RegisterUFunc, T, T, T>, T>(numpy, "heaviside") &&
      RegisterUFunc, T, T>, T>(numpy, "conjugate") &&
      RegisterUFunc, T, T>, T>(numpy, "exp") &&
      RegisterUFunc, T, T>, T>(numpy, "exp2") &&
      RegisterUFunc, T, T>, T>(numpy, "expm1") &&
      RegisterUFunc, T, T>, T>(numpy, "log") &&
      RegisterUFunc, T, T>, T>(numpy, "log2") &&
      RegisterUFunc, T, T>, T>(numpy, "log10") &&
      RegisterUFunc, T, T>, T>(numpy, "log1p") &&
      RegisterUFunc, T, T>, T>(numpy, "sqrt") &&
      RegisterUFunc, T, T>, T>(numpy, "square") &&
      RegisterUFunc, T, T>, T>(numpy, "cbrt") &&
      RegisterUFunc, T, T>, T>(numpy, "reciprocal") &&

      // Trigonometric functions
      RegisterUFunc, T, T>, T>(numpy, "sin") &&
      RegisterUFunc, T, T>, T>(numpy, "cos") &&
      RegisterUFunc, T, T>, T>(numpy, "tan") &&
      RegisterUFunc, T, T>, T>(numpy, "arcsin") &&
      RegisterUFunc, T, T>, T>(numpy, "arccos") &&
      RegisterUFunc, T, T>, T>(numpy, "arctan") &&
      RegisterUFunc, T, T, T>, T>(numpy, "arctan2") &&
      RegisterUFunc, T, T, T>, T>(numpy, "hypot") &&
      RegisterUFunc, T, T>, T>(numpy, "sinh") &&
      RegisterUFunc, T, T>, T>(numpy, "cosh") &&
      RegisterUFunc, T, T>, T>(numpy, "tanh") &&
      RegisterUFunc, T, T>, T>(numpy, "arcsinh") &&
      RegisterUFunc, T, T>, T>(numpy, "arccosh") &&
      RegisterUFunc, T, T>, T>(numpy, "arctanh") &&
      RegisterUFunc, T, T>, T>(numpy, "deg2rad") &&
      RegisterUFunc, T, T>, T>(numpy, "rad2deg") &&

      // Comparison functions
      RegisterUFunc, bool, T, T>, T>(numpy, "equal") &&
      RegisterUFunc, bool, T, T>, T>(numpy, "not_equal") &&
      RegisterUFunc, bool, T, T>, T>(numpy, "less") &&
      RegisterUFunc, bool, T, T>, T>(numpy, "greater") &&
      RegisterUFunc, bool, T, T>, T>(numpy, "less_equal") &&
      RegisterUFunc, bool, T, T>, T>(numpy, "greater_equal") &&
      RegisterUFunc, T, T, T>, T>(numpy, "maximum") &&
      RegisterUFunc, T, T, T>, T>(numpy, "minimum") &&
      RegisterUFunc, T, T, T>, T>(numpy, "fmax") &&
      RegisterUFunc, T, T, T>, T>(numpy, "fmin") &&
      RegisterUFunc, bool, T, T>, T>(
          numpy, "logical_and") &&
      RegisterUFunc, bool, T, T>, T>(numpy, "logical_or") &&
      RegisterUFunc, bool, T, T>, T>(
          numpy, "logical_xor") &&
      RegisterUFunc, bool, T>, T>(numpy, "logical_not") &&

      // Floating point functions
      RegisterUFunc, bool, T>, T>(numpy, "isfinite") &&
      RegisterUFunc, bool, T>, T>(numpy, "isinf") &&
      RegisterUFunc, bool, T>, T>(numpy, "isnan") &&
      RegisterUFunc, bool, T>, T>(numpy, "signbit") &&
      RegisterUFunc, T, T, T>, T>(numpy, "copysign") &&
      RegisterUFunc, T, T, T>, T>(numpy, "modf") &&
      RegisterUFunc, T, T, int>, T>(numpy, "ldexp") &&
      RegisterUFunc, T, int, T>, T>(numpy, "frexp") &&
      RegisterUFunc, T, T>, T>(numpy, "floor") &&
      RegisterUFunc, T, T>, T>(numpy, "ceil") &&
      RegisterUFunc, T, T>, T>(numpy, "trunc") &&
      RegisterUFunc, T, T, T>, T>(numpy, "nextafter") &&
      RegisterUFunc, T, T>, T>(numpy, "spacing");

  return ok;
}

// Registers the custom float type with Python and NumPy:
// - creates the Python scalar type (sets type_ptr and __module__);
// - fills in arr_funcs and registers the NumPy dtype (sets npy_type and
//   npy_descr);
// - exposes the type through numpy.sctypeDict and the type's `dtype`
//   attribute;
// - registers casts and ufunc loops.
// Returns false on the first failure.
template bool RegisterFloatDtype(PyObject* numpy) {
  // bases must be a tuple for Python 3.9 and earlier. Change to just pass
  // the base type directly when dropping Python 3.9 support.
  // TODO(jakevdp): it would be better to inherit from PyNumberArrType or
  // PyFloatingArrType, but this breaks some assumptions made by NumPy, because
  // dtype.kind='V' is then interpreted as a 'void' type in some contexts.
  Safe_PyObjectPtr bases(
      PyTuple_Pack(1, reinterpret_cast(&PyGenericArrType_Type)));
  PyObject* type =
      PyType_FromSpecWithBases(&CustomFloatType::type_spec, bases.get());
  if (!type) {
    return false;
  }
  TypeDescriptor::type_ptr = type;

  Safe_PyObjectPtr module = make_safe(PyUnicode_FromString("ml_dtypes"));
  if (!module) {
    return false;
  }
  if (PyObject_SetAttrString(type, "__module__", module.get()) < 0) {
    return false;
  }

  // Initializes the NumPy descriptor.
  PyArray_ArrFuncs& arr_funcs = CustomFloatType::arr_funcs;
  PyArray_InitArrFuncs(&arr_funcs);
  arr_funcs.getitem = NPyCustomFloat_GetItem;
  arr_funcs.setitem = NPyCustomFloat_SetItem;
  // NOTE(review): `compare` is assigned twice in this function. The later
  // NPyCustomFloat_CompareFunc assignment wins, so the NaN-aware
  // NPyCustomFloat_Compare is not what ends up installed.
  arr_funcs.compare = NPyCustomFloat_Compare;
  arr_funcs.copyswapn = NPyCustomFloat_CopySwapN;
  arr_funcs.copyswap = NPyCustomFloat_CopySwap;
  arr_funcs.nonzero = NPyCustomFloat_NonZero;
  arr_funcs.fill = NPyCustomFloat_Fill;
  arr_funcs.dotfunc = NPyCustomFloat_DotFunc;
  arr_funcs.compare = NPyCustomFloat_CompareFunc;
  arr_funcs.argmax = NPyCustomFloat_ArgMaxFunc;
  arr_funcs.argmin = NPyCustomFloat_ArgMinFunc;

  // This is messy, but that's because the NumPy 2.0 API transition is messy.
  // Before 2.0, NumPy assumes we'll keep the descriptor passed in to
  // RegisterDataType alive, because it stores its pointer.
  // After 2.0, the proto and descriptor types diverge, and NumPy allocates
  // and manages the lifetime of the descriptor itself.
  PyArray_DescrProto& descr_proto = CustomFloatType::npy_descr_proto;
  descr_proto = GetCustomFloatDescrProto();
  Py_SET_TYPE(&descr_proto, &PyArrayDescr_Type);
  descr_proto.typeobj = reinterpret_cast(type);

  TypeDescriptor::npy_type = PyArray_RegisterDataType(&descr_proto);
  if (TypeDescriptor::npy_type < 0) {
    return false;
  }
  // TODO(phawkins): We intentionally leak the pointer to the descriptor.
  // Implement a better module destructor to handle this.
  CustomFloatType::npy_descr =
      PyArray_DescrFromType(TypeDescriptor::npy_type);

  Safe_PyObjectPtr typeDict_obj =
      make_safe(PyObject_GetAttrString(numpy, "sctypeDict"));
  if (!typeDict_obj) return false;
  // Add the type object to `numpy.typeDict`: that makes
  // `numpy.dtype(type_name)` work.
  if (PyDict_SetItemString(typeDict_obj.get(), TypeDescriptor::kTypeName,
                           TypeDescriptor::type_ptr) < 0) {
    return false;
  }

  // Support dtype(type_name)
  if (PyObject_SetAttrString(
          TypeDescriptor::type_ptr, "dtype",
          reinterpret_cast(CustomFloatType::npy_descr)) < 0) {
    return false;
  }

  return RegisterFloatCasts() && RegisterFloatUFuncs(numpy);
}

}  // namespace ml_dtypes

#if NPY_ABI_VERSION < 0x02000000
#undef PyArray_DescrProto
#endif

#endif  // ML_DTYPES_CUSTOM_FLOAT_H_
jax-ml-ml_dtypes-882eb0f/ml_dtypes/_src/dtypes.cc000066400000000000000000000467371510671665600221230ustar00rootroot00000000000000/* Copyright 2017 The ml_dtypes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Enable cmath defines on Windows
#define _USE_MATH_DEFINES

// Must be included first
// clang-format off
#include "ml_dtypes/_src/numpy.h" //NOLINT
// clang-format on

// NOTE(review): the header names inside <...> on the following #include lines
// are missing in this copy; restore them from the upstream source.
#include  // NOLINT
#include  // NOLINT
#include  // NOLINT
#include  // NOLINT
#include  // NOLINT
// Place `` before to avoid a build failure in macOS.
#include
#include "Eigen/Core"
#include "ml_dtypes/_src/custom_float.h"
#include "ml_dtypes/_src/intn_numpy.h"
#include "ml_dtypes/include/float8.h"
#include "ml_dtypes/include/intn.h"
#include "ml_dtypes/include/mxfloat.h"

namespace ml_dtypes {

using bfloat16 = Eigen::bfloat16;

// TypeDescriptor specializations. Each one supplies the C++ value type, the
// Python type name, qualified name and docstring, and the NumPy descriptor
// kind/type/byteorder characters used when the dtype is registered.
// NOTE(review): kNpyDescrType 'C' is used by both float8_e5m2fnuz and uint2
// below. The TODOs note that uniqueness of these characters is not enforced.

template <>
struct TypeDescriptor : CustomFloatType {
  typedef bfloat16 T;
  static constexpr bool is_floating = true;
  static constexpr bool is_integral = false;
  static constexpr const char* kTypeName = "bfloat16";
  static constexpr const char* kQualifiedTypeName = "ml_dtypes.bfloat16";
  static constexpr const char* kTpDoc = "bfloat16 floating-point values";
  // We must register bfloat16 with a kind other than "f", because numpy
  // considers two types with the same kind and size to be equal, but
  // float16 != bfloat16.
  // The downside of this is that NumPy scalar promotion does not work with
  // bfloat16 values.
  static constexpr char kNpyDescrKind = 'V';
  // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
  // character is unique.
  static constexpr char kNpyDescrType = 'E';
  static constexpr char kNpyDescrByteorder = '=';
};

template <>
struct TypeDescriptor : CustomFloatType {
  typedef float8_e3m4 T;
  static constexpr bool is_floating = true;
  static constexpr bool is_integral = false;
  static constexpr const char* kTypeName = "float8_e3m4";
  static constexpr const char* kQualifiedTypeName = "ml_dtypes.float8_e3m4";
  static constexpr const char* kTpDoc = "float8_e3m4 floating-point values";
  // Set e3m4 kind as Void since kind=f (float) with itemsize=1 is used by e5m2
  static constexpr char kNpyDescrKind = 'V';  // Void
  static constexpr char kNpyDescrType = '3';
  static constexpr char kNpyDescrByteorder = '=';  // Native byte order
};

template <>
struct TypeDescriptor : CustomFloatType {
  typedef float8_e4m3 T;
  static constexpr bool is_floating = true;
  static constexpr bool is_integral = false;
  static constexpr const char* kTypeName = "float8_e4m3";
  static constexpr const char* kQualifiedTypeName = "ml_dtypes.float8_e4m3";
  static constexpr const char* kTpDoc = "float8_e4m3 floating-point values";
  // Set e4m3 kind as Void since kind=f (float) with itemsize=1 is used by e5m2
  static constexpr char kNpyDescrKind = 'V';  // Void
  static constexpr char kNpyDescrType = '7';  // '4' is reserved for e4m3fn
  static constexpr char kNpyDescrByteorder = '=';  // Native byte order
};

template <>
struct TypeDescriptor : CustomFloatType {
  typedef float8_e4m3b11fnuz T;
  static constexpr bool is_floating = true;
  static constexpr bool is_integral = false;
  static constexpr const char* kTypeName = "float8_e4m3b11fnuz";
  static constexpr const char* kQualifiedTypeName =
      "ml_dtypes.float8_e4m3b11fnuz";
  static constexpr const char* kTpDoc =
      "float8_e4m3b11fnuz floating-point values";
  // We must register float8_e4m3b11fnuz with a kind other than "f", because
  // numpy considers two types with the same kind and size to be equal, and we
  // expect multiple 1 byte floating point types.
  // The downside of this is that NumPy scalar promotion does not work with
  // float8_e4m3b11fnuz values.
  static constexpr char kNpyDescrKind = 'V';
  // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
  // character is unique.
  static constexpr char kNpyDescrType = 'L';
  static constexpr char kNpyDescrByteorder = '=';
};

template <>
struct TypeDescriptor : CustomFloatType {
  typedef float8_e4m3fn T;
  static constexpr bool is_floating = true;
  static constexpr bool is_integral = false;
  static constexpr const char* kTypeName = "float8_e4m3fn";
  static constexpr const char* kQualifiedTypeName = "ml_dtypes.float8_e4m3fn";
  static constexpr const char* kTpDoc = "float8_e4m3fn floating-point values";
  // We must register float8_e4m3fn with a unique kind, because numpy
  // considers two types with the same kind and size to be equal.
  // The downside of this is that NumPy scalar promotion does not work with
  // float8 values.  Using 'V' to mirror bfloat16 vs float16.
  static constexpr char kNpyDescrKind = 'V';
  // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
  // character is unique.
  static constexpr char kNpyDescrType = '4';
  static constexpr char kNpyDescrByteorder = '=';
};

template <>
struct TypeDescriptor : CustomFloatType {
  typedef float8_e4m3fnuz T;
  static constexpr bool is_floating = true;
  static constexpr bool is_integral = false;
  static constexpr const char* kTypeName = "float8_e4m3fnuz";
  static constexpr const char* kQualifiedTypeName = "ml_dtypes.float8_e4m3fnuz";
  static constexpr const char* kTpDoc = "float8_e4m3fnuz floating-point values";
  static constexpr char kNpyDescrKind = 'V';
  // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
  // character is unique.
  static constexpr char kNpyDescrType = 'G';
  static constexpr char kNpyDescrByteorder = '=';
};

template <>
struct TypeDescriptor : CustomFloatType {
  typedef float8_e5m2 T;
  static constexpr bool is_floating = true;
  static constexpr bool is_integral = false;
  static constexpr const char* kTypeName = "float8_e5m2";
  static constexpr const char* kQualifiedTypeName = "ml_dtypes.float8_e5m2";
  static constexpr const char* kTpDoc = "float8_e5m2 floating-point values";
  // Treating e5m2 as the natural "float" type since it is IEEE-754 compliant.
  static constexpr char kNpyDescrKind = 'f';
  // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
  // character is unique.
  static constexpr char kNpyDescrType = '5';
  static constexpr char kNpyDescrByteorder = '=';
};

template <>
struct TypeDescriptor : CustomFloatType {
  typedef float8_e5m2fnuz T;
  static constexpr bool is_floating = true;
  static constexpr bool is_integral = false;
  static constexpr const char* kTypeName = "float8_e5m2fnuz";
  static constexpr const char* kQualifiedTypeName = "ml_dtypes.float8_e5m2fnuz";
  static constexpr const char* kTpDoc = "float8_e5m2fnuz floating-point values";
  static constexpr char kNpyDescrKind = 'V';
  // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
  // character is unique.
  static constexpr char kNpyDescrType = 'C';  // NOTE(review): also used by uint2
  static constexpr char kNpyDescrByteorder = '=';
};

template <>
struct TypeDescriptor : CustomFloatType {
  typedef float6_e2m3fn T;
  static constexpr bool is_floating = true;
  static constexpr bool is_integral = false;
  static constexpr const char* kTypeName = "float6_e2m3fn";
  static constexpr const char* kQualifiedTypeName = "ml_dtypes.float6_e2m3fn";
  static constexpr const char* kTpDoc = "float6_e2m3fn floating-point values";
  static constexpr char kNpyDescrKind = 'V';
  static constexpr char kNpyDescrType = '8';
  static constexpr char kNpyDescrByteorder = '=';
};

template <>
struct TypeDescriptor : CustomFloatType {
  typedef float6_e3m2fn T;
  static constexpr bool is_floating = true;
  static constexpr bool is_integral = false;
  static constexpr const char* kTypeName = "float6_e3m2fn";
  static constexpr const char* kQualifiedTypeName = "ml_dtypes.float6_e3m2fn";
  static constexpr const char* kTpDoc = "float6_e3m2fn floating-point values";
  static constexpr char kNpyDescrKind = 'V';
  static constexpr char kNpyDescrType = '9';
  static constexpr char kNpyDescrByteorder = '=';
};

template <>
struct TypeDescriptor : CustomFloatType {
  typedef float4_e2m1fn T;
  static constexpr bool is_floating = true;
  static constexpr bool is_integral = false;
  static constexpr const char* kTypeName = "float4_e2m1fn";
  static constexpr const char* kQualifiedTypeName = "ml_dtypes.float4_e2m1fn";
  static constexpr const char* kTpDoc = "float4_e2m1fn floating-point values";
  static constexpr char kNpyDescrKind = 'V';
  static constexpr char kNpyDescrType = '0';
  static constexpr char kNpyDescrByteorder = '=';
};

template <>
struct TypeDescriptor : CustomFloatType {
  typedef float8_e8m0fnu T;
  static constexpr bool is_floating = true;
  static constexpr bool is_integral = false;
  static constexpr const char* kTypeName = "float8_e8m0fnu";
  static constexpr const char* kQualifiedTypeName = "ml_dtypes.float8_e8m0fnu";
  static constexpr const char* kTpDoc = "float8_e8m0fnu floating-point values";
  static constexpr char kNpyDescrKind = 'V';
  // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
  // character is unique.
  static constexpr char kNpyDescrType = 'W';
  static constexpr char kNpyDescrByteorder = '=';
};

template <>
struct TypeDescriptor : IntNTypeDescriptor {
  typedef int2 T;
  static constexpr bool is_floating = false;
  static constexpr bool is_integral = true;
  static constexpr const char* kTypeName = "int2";
  static constexpr const char* kQualifiedTypeName = "ml_dtypes.int2";
  static constexpr const char* kTpDoc = "int2 integer values";
  static constexpr char kNpyDescrKind = 'V';
  // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
  // character is unique.
  static constexpr char kNpyDescrType = 'c';
  static constexpr char kNpyDescrByteorder = '=';
};

template <>
struct TypeDescriptor : IntNTypeDescriptor {
  typedef uint2 T;
  static constexpr bool is_floating = false;
  static constexpr bool is_integral = true;
  static constexpr const char* kTypeName = "uint2";
  static constexpr const char* kQualifiedTypeName = "ml_dtypes.uint2";
  static constexpr const char* kTpDoc = "uint2 integer values";
  static constexpr char kNpyDescrKind = 'V';
  // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
  // character is unique.
  static constexpr char kNpyDescrType = 'C';  // NOTE(review): also float8_e5m2fnuz
  static constexpr char kNpyDescrByteorder = '=';
};

template <>
struct TypeDescriptor : IntNTypeDescriptor {
  typedef int4 T;
  static constexpr bool is_floating = false;
  static constexpr bool is_integral = true;
  static constexpr const char* kTypeName = "int4";
  static constexpr const char* kQualifiedTypeName = "ml_dtypes.int4";
  static constexpr const char* kTpDoc = "int4 integer values";
  static constexpr char kNpyDescrKind = 'V';
  // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
  // character is unique.
  static constexpr char kNpyDescrType = 'a';
  static constexpr char kNpyDescrByteorder = '=';
};

template <>
struct TypeDescriptor : IntNTypeDescriptor {
  typedef uint4 T;
  static constexpr bool is_floating = false;
  static constexpr bool is_integral = true;
  static constexpr const char* kTypeName = "uint4";
  static constexpr const char* kQualifiedTypeName = "ml_dtypes.uint4";
  static constexpr const char* kTpDoc = "uint4 integer values";
  static constexpr char kNpyDescrKind = 'V';
  // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
  // character is unique.
  static constexpr char kNpyDescrType = 'A';
  static constexpr char kNpyDescrByteorder = '=';
};

namespace {

// Performs a NumPy array cast from type 'From' to 'To' via `Via`.
// Each of the n elements is converted to the intermediate type and then to
// the destination type.
template
void PyCast(void* from_void, void* to_void, npy_intp n, void* fromarr,
            void* toarr) {
  const auto* from = static_cast(from_void);
  auto* to = static_cast(to_void);
  for (npy_intp i = 0; i < n; ++i) {
    to[i] = static_cast(static_cast(from[i]));
  }
}

// Registers PyCast-based conversion functions in both directions between the
// NumPy dtypes registered for the two template type arguments. Returns false
// if either registration fails.
// NOTE(review): the descriptors from PyArray_DescrFromType are new references
// that are never released here (same pattern in RegisterOneWayCustomCast).
template
bool RegisterTwoWayCustomCast() {
  int nptype1 = TypeDescriptor::npy_type;
  int nptype2 = TypeDescriptor::npy_type;
  PyArray_Descr* descr1 = PyArray_DescrFromType(nptype1);
  if (PyArray_RegisterCastFunc(descr1, nptype2, PyCast) < 0) {
    return false;
  }
  PyArray_Descr* descr2 = PyArray_DescrFromType(nptype2);
  if (PyArray_RegisterCastFunc(descr2, nptype1, PyCast) < 0) {
    return false;
  }
  return true;
}

// Like RegisterTwoWayCustomCast, but registers only the first-to-second
// direction.
template
bool RegisterOneWayCustomCast() {
  int nptype1 = TypeDescriptor::npy_type;
  int nptype2 = TypeDescriptor::npy_type;
  PyArray_Descr* descr1 = PyArray_DescrFromType(nptype1);
  if (PyArray_RegisterCastFunc(descr1, nptype2, PyCast) < 0) {
    return false;
  }
  return true;
}

// Register two-way floating point casts between the first and the other types.
// The first overload is the base case of the recursion (no types left).
template
bool RegisterTwoWayFloatCasts() {
  return true;
}

template
bool RegisterTwoWayFloatCasts() {
  return RegisterTwoWayCustomCast() &&
         RegisterTwoWayFloatCasts();
}

// Register two-way floating point casts between all pairs of types.
template bool RegisterAllFloatCasts() { return true; } template bool RegisterAllFloatCasts() { return RegisterTwoWayFloatCasts() && RegisterAllFloatCasts(); } // Initialize type attribute in the module object. template bool InitModuleType(PyObject* obj, const char* name) { return PyObject_SetAttrString( obj, name, reinterpret_cast(TypeDescriptor::type_ptr)) >= 0; } } // namespace // Initializes the module. bool Initialize() { ml_dtypes::ImportNumpy(); import_umath1(false); Safe_PyObjectPtr numpy_str = make_safe(PyUnicode_FromString("numpy")); if (!numpy_str) { return false; } Safe_PyObjectPtr numpy = make_safe(PyImport_Import(numpy_str.get())); if (!numpy) { return false; } if (!RegisterFloatDtype(numpy.get()) || !RegisterFloatDtype(numpy.get()) || !RegisterFloatDtype(numpy.get()) || !RegisterFloatDtype(numpy.get()) || !RegisterFloatDtype(numpy.get()) || !RegisterFloatDtype(numpy.get()) || !RegisterFloatDtype(numpy.get()) || !RegisterFloatDtype(numpy.get()) || !RegisterFloatDtype(numpy.get()) || !RegisterFloatDtype(numpy.get()) || !RegisterFloatDtype(numpy.get())) { return false; } if (!RegisterFloatDtype(numpy.get())) { return false; } if (!RegisterIntNDtype(numpy.get()) || !RegisterIntNDtype(numpy.get()) || !RegisterIntNDtype(numpy.get()) || !RegisterIntNDtype(numpy.get())) { return false; } // Register casts between pairs of custom float dtypes. bool success = RegisterAllFloatCasts(); // Only registering to/from BF16 and FP32 for float8_e8m0fnu. success &= RegisterTwoWayCustomCast(); success &= RegisterTwoWayCustomCast(); success &= RegisterOneWayCustomCast(); success &= RegisterOneWayCustomCast(); // Int -> float casts. success &= RegisterTwoWayFloatCasts(); success &= RegisterTwoWayFloatCasts(); success &= RegisterTwoWayFloatCasts(); // int4 -> float6_e2m3fn is not safe and we only register safe casts. success &= RegisterTwoWayFloatCasts(); // uint4 -> float6_e2m3fn is not safe and we only register safe casts. 
return success; } static PyModuleDef module_def = { PyModuleDef_HEAD_INIT, "_ml_dtypes_ext", }; // TODO(phawkins): PyMODINIT_FUNC handles visibility correctly in Python 3.9+. // Just use PyMODINIT_FUNC after dropping Python 3.8 support. #if defined(WIN32) || defined(_WIN32) #define EXPORT_SYMBOL __declspec(dllexport) #else #define EXPORT_SYMBOL __attribute__((visibility("default"))) #endif extern "C" EXPORT_SYMBOL PyObject* PyInit__ml_dtypes_ext() { Safe_PyObjectPtr m = make_safe(PyModule_Create(&module_def)); if (!m) { return nullptr; } if (!Initialize()) { if (!PyErr_Occurred()) { PyErr_SetString(PyExc_RuntimeError, "cannot load _ml_dtypes_ext module."); } return nullptr; } if (!InitModuleType(m.get(), "float4_e2m1fn") || !InitModuleType(m.get(), "float6_e2m3fn") || !InitModuleType(m.get(), "float6_e3m2fn") || !InitModuleType(m.get(), "float8_e3m4") || !InitModuleType(m.get(), "float8_e4m3") || !InitModuleType(m.get(), "float8_e4m3b11fnuz") || !InitModuleType(m.get(), "float8_e4m3fn") || !InitModuleType(m.get(), "float8_e4m3fnuz") || !InitModuleType(m.get(), "float8_e5m2") || !InitModuleType(m.get(), "float8_e5m2fnuz") || !InitModuleType(m.get(), "float8_e8m0fnu") || !InitModuleType(m.get(), "bfloat16") || !InitModuleType(m.get(), "int2") || !InitModuleType(m.get(), "int4") || !InitModuleType(m.get(), "uint2") || !InitModuleType(m.get(), "uint4")) { return nullptr; } #ifdef Py_GIL_DISABLED PyUnstable_Module_SetGIL(m.get(), Py_MOD_GIL_NOT_USED); #endif return m.release(); } } // namespace ml_dtypes jax-ml-ml_dtypes-882eb0f/ml_dtypes/_src/intn_numpy.h000066400000000000000000000636571510671665600226550ustar00rootroot00000000000000/* Copyright 2023 The ml_dtypes Authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #ifndef ML_DTYPES_INT4_NUMPY_H_ #define ML_DTYPES_INT4_NUMPY_H_ #include #include // Must be included first // clang-format off #include "ml_dtypes/_src/numpy.h" // clang-format on #include "Eigen/Core" #include "ml_dtypes/_src/common.h" // NOLINT #include "ml_dtypes/_src/ufuncs.h" // NOLINT #include "ml_dtypes/include/intn.h" #if NPY_ABI_VERSION < 0x02000000 #define PyArray_DescrProto PyArray_Descr #endif namespace ml_dtypes { constexpr char kOutOfRange[] = "out of range value cannot be converted to int4"; template struct IntNTypeDescriptor { static int Dtype() { return npy_type; } // Registered numpy type ID. Global variable populated by the registration // code. Protected by the GIL. static int npy_type; // Pointer to the python type object we are using. This is either a pointer // to type, if we choose to register it, or to the python type // registered by another system into NumPy. static PyObject* type_ptr; static PyType_Spec type_spec; static PyType_Slot type_slots[]; static PyArray_ArrFuncs arr_funcs; static PyArray_DescrProto npy_descr_proto; static PyArray_Descr* npy_descr; }; template int IntNTypeDescriptor::npy_type = NPY_NOTYPE; template PyObject* IntNTypeDescriptor::type_ptr = nullptr; template PyArray_DescrProto IntNTypeDescriptor::npy_descr_proto; template PyArray_Descr* IntNTypeDescriptor::npy_descr = nullptr; // Representation of a Python custom integer object. template struct PyIntN { PyObject_HEAD; // Python object header T value; }; // Returns true if 'object' is a PyIntN. 
template bool PyIntN_Check(PyObject* object) { return PyObject_IsInstance(object, TypeDescriptor::type_ptr); } // Extracts the value of a PyIntN object. template T PyIntN_Value_Unchecked(PyObject* object) { return reinterpret_cast*>(object)->value; } template bool PyIntN_Value(PyObject* arg, T* output) { if (PyIntN_Check(arg)) { *output = PyIntN_Value_Unchecked(arg); return true; } return false; } // Constructs a PyIntN object from PyIntN::T. template Safe_PyObjectPtr PyIntN_FromValue(T x) { PyTypeObject* type = reinterpret_cast(TypeDescriptor::type_ptr); Safe_PyObjectPtr ref = make_safe(type->tp_alloc(type, 0)); PyIntN* p = reinterpret_cast*>(ref.get()); if (p) { p->value = x; } return ref; } // Converts a Python object to a reduced integer value. Returns true on success, // returns false and reports a Python error on failure. template bool CastToIntN(PyObject* arg, T* output) { if (PyIntN_Check(arg)) { *output = PyIntN_Value_Unchecked(arg); return true; } if (PyFloat_Check(arg)) { double d = PyFloat_AsDouble(arg); if (PyErr_Occurred()) { return false; } if (std::isnan(d)) { PyErr_SetString(PyExc_ValueError, "cannot convert float NaN to integer"); return false; } if (std::isinf(d)) { PyErr_SetString(PyExc_OverflowError, "cannot convert float infinity to integer"); return false; } if (d < static_cast(T::lowest()) || d > static_cast(T::highest())) { PyErr_SetString(PyExc_OverflowError, kOutOfRange); return false; } *output = T(d); return true; } if (PyLong_Check(arg)) { long l = PyLong_AsLong(arg); // NOLINT if (PyErr_Occurred()) { return false; } *output = T(l); return true; } if (PyArray_IsScalar(arg, Integer)) { int64_t v; PyArray_CastScalarToCtype(arg, &v, PyArray_DescrFromType(NPY_INT64)); if (!(std::numeric_limits::min() <= v && v <= std::numeric_limits::max())) { PyErr_SetString(PyExc_OverflowError, kOutOfRange); return false; } *output = T(v); return true; } auto floating_conversion = [&](auto type) -> bool { decltype(type) f; PyArray_ScalarAsCtype(arg, &f); 
if (!(std::numeric_limits::min() <= f && f <= std::numeric_limits::max())) { PyErr_SetString(PyExc_OverflowError, kOutOfRange); return false; } *output = T(static_cast<::int8_t>(f)); return true; }; if (PyArray_IsScalar(arg, Half)) { return floating_conversion(Eigen::half{}); } if (PyArray_IsScalar(arg, Float)) { return floating_conversion(float{}); } if (PyArray_IsScalar(arg, Double)) { return floating_conversion(double{}); } if (PyArray_IsScalar(arg, LongDouble)) { using ld = long double; return floating_conversion(ld{}); } return false; } // Constructs a new PyIntN. template PyObject* PyIntN_tp_new(PyTypeObject* type, PyObject* args, PyObject* kwds) { if (kwds && PyDict_Size(kwds)) { PyErr_SetString(PyExc_TypeError, "constructor takes no keyword arguments"); return nullptr; } Py_ssize_t size = PyTuple_Size(args); if (size != 1) { PyErr_Format(PyExc_TypeError, "expected number as argument to %s constructor", TypeDescriptor::kTypeName); return nullptr; } PyObject* arg = PyTuple_GetItem(args, 0); T value; if (PyIntN_Check(arg)) { Py_INCREF(arg); return arg; } else if (CastToIntN(arg, &value)) { return PyIntN_FromValue(value).release(); } else if (PyArray_Check(arg)) { PyArrayObject* arr = reinterpret_cast(arg); if (PyArray_TYPE(arr) != TypeDescriptor::Dtype()) { return PyArray_Cast(arr, TypeDescriptor::Dtype()); } else { Py_INCREF(arg); return arg; } } else if (PyUnicode_Check(arg) || PyBytes_Check(arg)) { // Parse float from string, then cast to T. 
PyObject* f = PyLong_FromUnicodeObject(arg, /*base=*/0); if (PyErr_Occurred()) { return nullptr; } if (CastToIntN(f, &value)) { return PyIntN_FromValue(value).release(); } } if (PyErr_Occurred()) { return nullptr; } PyErr_Format(PyExc_TypeError, "expected number, got %s", Py_TYPE(arg)->tp_name); return nullptr; } template PyObject* PyIntN_nb_float(PyObject* self) { T x = PyIntN_Value_Unchecked(self); return PyFloat_FromDouble(static_cast(x)); } template PyObject* PyIntN_nb_int(PyObject* self) { T x = PyIntN_Value_Unchecked(self); return PyLong_FromLong(static_cast(x)); // NOLINT } template PyObject* PyIntN_nb_negative(PyObject* self) { T x = PyIntN_Value_Unchecked(self); return PyIntN_FromValue(-x).release(); } template PyObject* PyIntN_nb_positive(PyObject* self) { T x = PyIntN_Value_Unchecked(self); return PyIntN_FromValue(x).release(); } template PyObject* PyIntN_nb_add(PyObject* a, PyObject* b) { T x, y; if (PyIntN_Value(a, &x) && PyIntN_Value(b, &y)) { return PyIntN_FromValue(x + y).release(); } return PyArray_Type.tp_as_number->nb_add(a, b); } template PyObject* PyIntN_nb_subtract(PyObject* a, PyObject* b) { T x, y; if (PyIntN_Value(a, &x) && PyIntN_Value(b, &y)) { return PyIntN_FromValue(x - y).release(); } return PyArray_Type.tp_as_number->nb_subtract(a, b); } template PyObject* PyIntN_nb_multiply(PyObject* a, PyObject* b) { T x, y; if (PyIntN_Value(a, &x) && PyIntN_Value(b, &y)) { return PyIntN_FromValue(x * y).release(); } return PyArray_Type.tp_as_number->nb_multiply(a, b); } template PyObject* PyIntN_nb_remainder(PyObject* a, PyObject* b) { T x, y; if (PyIntN_Value(a, &x) && PyIntN_Value(b, &y)) { if (y == 0) { PyErr_SetString(PyExc_ZeroDivisionError, "division by zero"); return nullptr; } T v = x % y; if (v != 0 && ((v < 0) != (y < 0))) { v = v + y; } return PyIntN_FromValue(v).release(); } return PyArray_Type.tp_as_number->nb_remainder(a, b); } template PyObject* PyIntN_nb_floor_divide(PyObject* a, PyObject* b) { T x, y; if (PyIntN_Value(a, &x) && 
PyIntN_Value(b, &y)) { if (y == 0) { PyErr_SetString(PyExc_ZeroDivisionError, "division by zero"); return nullptr; } T v = x / y; if (((x > 0) != (y > 0)) && x % y != 0) { v = v - T(1); } return PyIntN_FromValue(v).release(); } return PyArray_Type.tp_as_number->nb_floor_divide(a, b); } // Implementation of repr() for PyIntN. template PyObject* PyIntN_Repr(PyObject* self) { T x = PyIntN_Value_Unchecked(self); std::string s = x.ToString(); return PyUnicode_FromString(s.c_str()); } // Implementation of str() for PyIntN. template PyObject* PyIntN_Str(PyObject* self) { T x = PyIntN_Value_Unchecked(self); std::string s = x.ToString(); return PyUnicode_FromString(s.c_str()); } // Hash function for PyIntN. template Py_hash_t PyIntN_Hash(PyObject* self) { T x = PyIntN_Value_Unchecked(self); // Hash functions must not return -1. return static_cast(x) == -1 ? static_cast(-2) : static_cast(x); } // Comparisons on PyIntNs. template PyObject* PyIntN_RichCompare(PyObject* a, PyObject* b, int op) { T x, y; if (!PyIntN_Value(a, &x) || !PyIntN_Value(b, &y)) { return PyGenericArrType_Type.tp_richcompare(a, b, op); } bool result; switch (op) { case Py_LT: result = x < y; break; case Py_LE: result = x <= y; break; case Py_EQ: result = x == y; break; case Py_NE: result = x != y; break; case Py_GT: result = x > y; break; case Py_GE: result = x >= y; break; default: PyErr_SetString(PyExc_ValueError, "Invalid op type"); return nullptr; } PyArrayScalar_RETURN_BOOL_FROM_LONG(result); } template PyType_Slot IntNTypeDescriptor::type_slots[] = { {Py_tp_new, reinterpret_cast(PyIntN_tp_new)}, {Py_tp_repr, reinterpret_cast(PyIntN_Repr)}, {Py_tp_hash, reinterpret_cast(PyIntN_Hash)}, {Py_tp_str, reinterpret_cast(PyIntN_Str)}, {Py_tp_doc, reinterpret_cast(const_cast(TypeDescriptor::kTpDoc))}, {Py_tp_richcompare, reinterpret_cast(PyIntN_RichCompare)}, {Py_nb_add, reinterpret_cast(PyIntN_nb_add)}, {Py_nb_subtract, reinterpret_cast(PyIntN_nb_subtract)}, {Py_nb_multiply, 
reinterpret_cast(PyIntN_nb_multiply)}, {Py_nb_remainder, reinterpret_cast(PyIntN_nb_remainder)}, {Py_nb_negative, reinterpret_cast(PyIntN_nb_negative)}, {Py_nb_positive, reinterpret_cast(PyIntN_nb_positive)}, {Py_nb_int, reinterpret_cast(PyIntN_nb_int)}, {Py_nb_float, reinterpret_cast(PyIntN_nb_float)}, {Py_nb_floor_divide, reinterpret_cast(PyIntN_nb_floor_divide)}, {0, nullptr}, }; template PyType_Spec IntNTypeDescriptor::type_spec = { /*.name=*/TypeDescriptor::kQualifiedTypeName, /*.basicsize=*/static_cast(sizeof(PyIntN)), /*.itemsize=*/0, /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*.slots=*/IntNTypeDescriptor::type_slots, }; // Numpy support template PyArray_ArrFuncs IntNTypeDescriptor::arr_funcs; template PyArray_DescrProto GetIntNDescrProto() { return { PyObject_HEAD_INIT(nullptr) /*typeobj=*/nullptr, // Filled in later /*kind=*/TypeDescriptor::kNpyDescrKind, /*type=*/TypeDescriptor::kNpyDescrType, /*byteorder=*/TypeDescriptor::kNpyDescrByteorder, /*flags=*/NPY_USE_SETITEM, /*type_num=*/0, /*elsize=*/sizeof(T), /*alignment=*/alignof(T), /*subarray=*/nullptr, /*fields=*/nullptr, /*names=*/nullptr, /*f=*/&IntNTypeDescriptor::arr_funcs, /*metadata=*/nullptr, /*c_metadata=*/nullptr, /*hash=*/-1, // -1 means "not computed yet". }; } // Implementations of NumPy array methods. 
template PyObject* NPyIntN_GetItem(void* data, void* arr) { T x; memcpy(&x, data, sizeof(T)); return PyLong_FromLong(static_cast(x)); } template int NPyIntN_SetItem(PyObject* item, void* data, void* arr) { T x; if (!CastToIntN(item, &x)) { if (PyErr_Occurred()) { return -1; } PyErr_Format(PyExc_TypeError, "expected number, got %s", Py_TYPE(item)->tp_name); return -1; } memcpy(data, &x, sizeof(T)); return 0; } template int NPyIntN_Compare(const void* a, const void* b, void* arr) { T x; memcpy(&x, a, sizeof(T)); T y; memcpy(&y, b, sizeof(T)); int fy(y); int fx(x); if (fx < fy) { return -1; } if (fy < fx) { return 1; } return 0; } template void NPyIntN_CopySwapN(void* dstv, npy_intp dstride, void* srcv, npy_intp sstride, npy_intp n, int swap, void* arr) { char* dst = reinterpret_cast(dstv); char* src = reinterpret_cast(srcv); if (src) { if (dstride == sizeof(T) && sstride == sizeof(T)) { memcpy(dst, src, n * sizeof(T)); } else { for (npy_intp i = 0; i < n; i++) { memcpy(dst + dstride * i, src + sstride * i, sizeof(T)); } } } // Note: No byte swapping needed for 8-bit integer types } template void NPyIntN_CopySwap(void* dst, void* src, int swap, void* arr) { if (src) { memcpy(dst, src, sizeof(T)); } // Note: No byte swapping needed for 8-bit integer types } template npy_bool NPyIntN_NonZero(void* data, void* arr) { T x; memcpy(&x, data, sizeof(x)); return x != static_cast(0); } template int NPyIntN_Fill(void* buffer_raw, npy_intp length, void* ignored) { T* const buffer = reinterpret_cast(buffer_raw); const int start(buffer[0]); const int delta = static_cast(buffer[1]) - start; for (npy_intp i = 2; i < length; ++i) { buffer[i] = static_cast(start + i * delta); } return 0; } template void NPyIntN_DotFunc(void* ip1, npy_intp is1, void* ip2, npy_intp is2, void* op, npy_intp n, void* arr) { char* c1 = reinterpret_cast(ip1); char* c2 = reinterpret_cast(ip2); int acc = 0; for (npy_intp i = 0; i < n; ++i) { T* const b1 = reinterpret_cast(c1); T* const b2 = 
reinterpret_cast(c2); acc += static_cast(*b1) * static_cast(*b2); c1 += is1; c2 += is2; } T* out = reinterpret_cast(op); *out = static_cast(acc); } template int NPyIntN_CompareFunc(const void* v1, const void* v2, void* arr) { T b1 = *reinterpret_cast(v1); T b2 = *reinterpret_cast(v2); if (b1 < b2) { return -1; } if (b1 > b2) { return 1; } return 0; } template int NPyIntN_ArgMaxFunc(void* data, npy_intp n, npy_intp* max_ind, void* arr) { const T* bdata = reinterpret_cast(data); // Start with a max_val of INT_MIN, this results in the first iteration // preferring bdata[0]. int max_val = std::numeric_limits::lowest(); for (npy_intp i = 0; i < n; ++i) { if (static_cast(bdata[i]) > max_val) { max_val = static_cast(bdata[i]); *max_ind = i; } } return 0; } template int NPyIntN_ArgMinFunc(void* data, npy_intp n, npy_intp* min_ind, void* arr) { const T* bdata = reinterpret_cast(data); int min_val = std::numeric_limits::max(); // Start with a min_val of INT_MAX, this results in the first iteration // preferring bdata[0]. for (npy_intp i = 0; i < n; ++i) { if (static_cast(bdata[i]) < min_val) { min_val = static_cast(bdata[i]); *min_ind = i; } } return 0; } template int CastToInt(T value) { if constexpr (is_complex_v) { return CastToInt(value.real()); } else { static_assert(std::numeric_limits::is_specialized); if constexpr (!std::numeric_limits::is_integer) { if (std::isnan(value) || std::isinf(value) || value < std::numeric_limits::lowest() || value > std::numeric_limits::max()) { return 0; } } return static_cast(value); } } // Performs a NumPy array cast from type 'From' to 'To'. template void IntegerCast(void* from_void, void* to_void, npy_intp n, void* fromarr, void* toarr) { const auto* from = reinterpret_cast::T*>(from_void); auto* to = reinterpret_cast::T*>(to_void); for (npy_intp i = 0; i < n; ++i) { to[i] = static_cast::T>( static_cast(CastToInt(from[i]))); } } // Registers a cast between T (a reduced float) and type 'OtherT'. 
'numpy_type' // is the NumPy type corresponding to 'OtherT'. template bool RegisterCustomIntCast(int numpy_type = TypeDescriptor::Dtype()) { PyArray_Descr* descr = PyArray_DescrFromType(numpy_type); if (PyArray_RegisterCastFunc(descr, TypeDescriptor::Dtype(), IntegerCast) < 0) { return false; } if (PyArray_RegisterCastFunc(IntNTypeDescriptor::npy_descr, numpy_type, IntegerCast) < 0) { return false; } return true; } template bool RegisterIntNCasts() { if (!RegisterCustomIntCast(NPY_HALF)) { return false; } if (!RegisterCustomIntCast(NPY_FLOAT)) { return false; } if (!RegisterCustomIntCast(NPY_DOUBLE)) { return false; } if (!RegisterCustomIntCast(NPY_LONGDOUBLE)) { return false; } if (!RegisterCustomIntCast(NPY_BOOL)) { return false; } if (!RegisterCustomIntCast(NPY_UBYTE)) { return false; } if (!RegisterCustomIntCast(NPY_USHORT)) { // NOLINT return false; } if (!RegisterCustomIntCast(NPY_UINT)) { return false; } if (!RegisterCustomIntCast(NPY_ULONG)) { // NOLINT return false; } if (!RegisterCustomIntCast( // NOLINT NPY_ULONGLONG)) { return false; } if (!RegisterCustomIntCast(NPY_BYTE)) { return false; } if (!RegisterCustomIntCast(NPY_SHORT)) { // NOLINT return false; } if (!RegisterCustomIntCast(NPY_INT)) { return false; } if (!RegisterCustomIntCast(NPY_LONG)) { // NOLINT return false; } if (!RegisterCustomIntCast(NPY_LONGLONG)) { // NOLINT return false; } // Following the numpy convention. imag part is dropped when converting to // float. 
if (!RegisterCustomIntCast>(NPY_CFLOAT)) { return false; } if (!RegisterCustomIntCast>(NPY_CDOUBLE)) { return false; } if (!RegisterCustomIntCast>(NPY_CLONGDOUBLE)) { return false; } // Safe casts from T to other types if (PyArray_RegisterCanCast(TypeDescriptor::npy_descr, NPY_INT8, NPY_NOSCALAR) < 0) { return false; } if (PyArray_RegisterCanCast(TypeDescriptor::npy_descr, NPY_INT16, NPY_NOSCALAR) < 0) { return false; } if (PyArray_RegisterCanCast(TypeDescriptor::npy_descr, NPY_INT32, NPY_NOSCALAR) < 0) { return false; } if (PyArray_RegisterCanCast(TypeDescriptor::npy_descr, NPY_INT64, NPY_NOSCALAR) < 0) { return false; } if (!std::numeric_limits::is_signed) { if (PyArray_RegisterCanCast(TypeDescriptor::npy_descr, NPY_UINT8, NPY_NOSCALAR) < 0) { return false; } if (PyArray_RegisterCanCast(TypeDescriptor::npy_descr, NPY_UINT16, NPY_NOSCALAR) < 0) { return false; } if (PyArray_RegisterCanCast(TypeDescriptor::npy_descr, NPY_UINT32, NPY_NOSCALAR) < 0) { return false; } if (PyArray_RegisterCanCast(TypeDescriptor::npy_descr, NPY_UINT64, NPY_NOSCALAR) < 0) { return false; } } if (PyArray_RegisterCanCast(TypeDescriptor::npy_descr, NPY_HALF, NPY_NOSCALAR) < 0) { return false; } if (PyArray_RegisterCanCast(TypeDescriptor::npy_descr, NPY_FLOAT, NPY_NOSCALAR) < 0) { return false; } if (PyArray_RegisterCanCast(TypeDescriptor::npy_descr, NPY_DOUBLE, NPY_NOSCALAR) < 0) { return false; } if (PyArray_RegisterCanCast(TypeDescriptor::npy_descr, NPY_LONGDOUBLE, NPY_NOSCALAR) < 0) { return false; } if (PyArray_RegisterCanCast(TypeDescriptor::npy_descr, NPY_CFLOAT, NPY_NOSCALAR) < 0) { return false; } if (PyArray_RegisterCanCast(TypeDescriptor::npy_descr, NPY_CDOUBLE, NPY_NOSCALAR) < 0) { return false; } if (PyArray_RegisterCanCast(TypeDescriptor::npy_descr, NPY_CLONGDOUBLE, NPY_NOSCALAR) < 0) { return false; } // Safe casts to T from other types if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_BOOL), TypeDescriptor::Dtype(), NPY_NOSCALAR) < 0) { return false; } return true; } 
template bool RegisterIntNUFuncs(PyObject* numpy) { bool ok = RegisterUFunc, T, T, T>, T>(numpy, "add") && RegisterUFunc, T, T, T>, T>(numpy, "subtract") && RegisterUFunc, T, T, T>, T>(numpy, "multiply") && RegisterUFunc, T, T, T>, T>( numpy, "floor_divide") && RegisterUFunc, T, T, T>, T>(numpy, "remainder"); return ok; } template bool RegisterIntNDtype(PyObject* numpy) { // bases must be a tuple for Python 3.9 and earlier. Change to just pass // the base type directly when dropping Python 3.9 support. // TODO(jakevdp): it would be better to inherit from PyNumberArrType or // PyIntegerArrType, but this breaks some assumptions made by NumPy, because // dtype.kind='V' is then interpreted as a 'void' type in some contexts. Safe_PyObjectPtr bases( PyTuple_Pack(1, reinterpret_cast(&PyGenericArrType_Type))); PyObject* type = PyType_FromSpecWithBases(&IntNTypeDescriptor::type_spec, bases.get()); if (!type) { return false; } TypeDescriptor::type_ptr = type; Safe_PyObjectPtr module = make_safe(PyUnicode_FromString("ml_dtypes")); if (!module) { return false; } if (PyObject_SetAttrString(TypeDescriptor::type_ptr, "__module__", module.get()) < 0) { return false; } // Initializes the NumPy descriptor. PyArray_ArrFuncs& arr_funcs = IntNTypeDescriptor::arr_funcs; PyArray_InitArrFuncs(&arr_funcs); arr_funcs.getitem = NPyIntN_GetItem; arr_funcs.setitem = NPyIntN_SetItem; arr_funcs.compare = NPyIntN_Compare; arr_funcs.copyswapn = NPyIntN_CopySwapN; arr_funcs.copyswap = NPyIntN_CopySwap; arr_funcs.nonzero = NPyIntN_NonZero; arr_funcs.fill = NPyIntN_Fill; arr_funcs.dotfunc = NPyIntN_DotFunc; arr_funcs.compare = NPyIntN_CompareFunc; arr_funcs.argmax = NPyIntN_ArgMaxFunc; arr_funcs.argmin = NPyIntN_ArgMinFunc; // This is messy, but that's because the NumPy 2.0 API transition is messy. // Before 2.0, NumPy assumes we'll keep the descriptor passed in to // RegisterDataType alive, because it stores its pointer. 
// After 2.0, the proto and descriptor types diverge, and NumPy allocates // and manages the lifetime of the descriptor itself. PyArray_DescrProto& descr_proto = IntNTypeDescriptor::npy_descr_proto; descr_proto = GetIntNDescrProto(); Py_SET_TYPE(&descr_proto, &PyArrayDescr_Type); descr_proto.typeobj = reinterpret_cast(type); TypeDescriptor::npy_type = PyArray_RegisterDataType(&descr_proto); if (TypeDescriptor::npy_type < 0) { return false; } // TODO(phawkins): We intentionally leak the pointer to the descriptor. // Implement a better module destructor to handle this. IntNTypeDescriptor::npy_descr = PyArray_DescrFromType(TypeDescriptor::npy_type); Safe_PyObjectPtr typeDict_obj = make_safe(PyObject_GetAttrString(numpy, "sctypeDict")); if (!typeDict_obj) return false; // Add the type object to `numpy.typeDict`: that makes // `numpy.dtype(type_name)` work. if (PyDict_SetItemString(typeDict_obj.get(), TypeDescriptor::kTypeName, TypeDescriptor::type_ptr) < 0) { return false; } // Support dtype(type_name) if (PyObject_SetAttrString( TypeDescriptor::type_ptr, "dtype", reinterpret_cast(IntNTypeDescriptor::npy_descr)) < 0) { return false; } return RegisterIntNCasts() && RegisterIntNUFuncs(numpy); } } // namespace ml_dtypes #if NPY_ABI_VERSION < 0x02000000 #undef PyArray_DescrProto #endif #endif // ML_DTYPES_INT4_NUMPY_H_ jax-ml-ml_dtypes-882eb0f/ml_dtypes/_src/numpy.cc000066400000000000000000000016251510671665600217460ustar00rootroot00000000000000/* Copyright 2022 The ml_dtypes Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ // We define the PY_ARRAY_UNIQUE_SYMBOL in this .cc file and provide an // ImportNumpy function to populate it. #define ML_DTYPES_IMPORT_NUMPY #include "ml_dtypes/_src/numpy.h" namespace ml_dtypes { void ImportNumpy() { import_array1(); } } // namespace ml_dtypes jax-ml-ml_dtypes-882eb0f/ml_dtypes/_src/numpy.h000066400000000000000000000030511510671665600216030ustar00rootroot00000000000000/* Copyright 2022 The ml_dtypes Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #ifndef ML_DTYPES__NUMPY_H_ #define ML_DTYPES__NUMPY_H_ #ifdef PyArray_Type #error "Numpy cannot be included before numpy.h." #endif // Disallow Numpy 1.7 deprecated symbols. #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION // We import_array in the ml_dtypes init function only. #define PY_ARRAY_UNIQUE_SYMBOL _ml_dtypes_numpy_api #ifndef ML_DTYPES_IMPORT_NUMPY #define NO_IMPORT_ARRAY #endif // Place `` before to avoid build failure in macOS. #include #include #include "numpy/arrayobject.h" #include "numpy/arrayscalars.h" #include "numpy/ufuncobject.h" namespace ml_dtypes { // Import numpy. This wrapper function exists so that the // PY_ARRAY_UNIQUE_SYMBOL can be safely defined in a .cc file to // avoid weird linking issues. 
Should be called only from our // module initialization function. void ImportNumpy(); } // namespace ml_dtypes #endif // ML_DTYPES__NUMPY_H_ jax-ml-ml_dtypes-882eb0f/ml_dtypes/_src/ufuncs.h000066400000000000000000000456231510671665600217510ustar00rootroot00000000000000/* Copyright 2022 The ml_dtypes Authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #ifndef ML_DTYPES_UFUNCS_H_ #define ML_DTYPES_UFUNCS_H_ // Must be included first // clang-format off #include "ml_dtypes/_src/numpy.h" // clang-format on #include // NOLINT #include // NOLINT #include // NOLINT #include // NOLINT #include // NOLINT #include // NOLINT #include // NOLINT #include "ml_dtypes/_src/common.h" // NOLINT // Some versions of MSVC define a "copysign" macro which wreaks havoc. 
#if defined(_MSC_VER) && defined(copysign) #undef copysign #endif namespace ml_dtypes { template struct UFunc { static std::vector Types() { return {TypeDescriptor::Dtype()..., TypeDescriptor::Dtype()}; } static constexpr int kInputArity = sizeof...(InTypes); template static void CallImpl(std::index_sequence, char** args, const npy_intp* dimensions, const npy_intp* steps, void* data) { std::array inputs = {args[Is]...}; char* o = args[kInputArity]; for (npy_intp k = 0; k < *dimensions; k++) { *reinterpret_cast(o) = Functor()(*reinterpret_cast(inputs[Is])...); ([&]() { inputs[Is] += steps[Is]; }(), ...); o += steps[kInputArity]; } } static void Call(char** args, const npy_intp* dimensions, const npy_intp* steps, void* data) { return CallImpl(std::index_sequence_for(), args, dimensions, steps, data); } }; template struct UFunc2 { static std::vector Types() { return { TypeDescriptor::Dtype()..., TypeDescriptor::Dtype(), TypeDescriptor::Dtype(), }; } static constexpr int kInputArity = sizeof...(InTypes); template static void CallImpl(std::index_sequence, char** args, const npy_intp* dimensions, const npy_intp* steps, void* data) { std::array inputs = {args[Is]...}; char* o0 = args[kInputArity]; char* o1 = args[kInputArity + 1]; for (npy_intp k = 0; k < *dimensions; k++) { std::tie(*reinterpret_cast(o0), *reinterpret_cast(o1)) = Functor()(*reinterpret_cast(inputs[Is])...); ([&]() { inputs[Is] += steps[Is]; }(), ...); o0 += steps[kInputArity]; o1 += steps[kInputArity + 1]; } } static void Call(char** args, const npy_intp* dimensions, const npy_intp* steps, void* data) { return CallImpl(std::index_sequence_for(), args, dimensions, steps, data); } }; template bool RegisterUFunc(PyObject* numpy, const char* name) { std::vector types = UFuncT::Types(); PyUFuncGenericFunction fn = reinterpret_cast(UFuncT::Call); Safe_PyObjectPtr ufunc_obj = make_safe(PyObject_GetAttrString(numpy, name)); if (!ufunc_obj) { return false; } PyUFuncObject* ufunc = 
reinterpret_cast(ufunc_obj.get()); if (static_cast(types.size()) != ufunc->nargs) { PyErr_Format(PyExc_AssertionError, "ufunc %s takes %d arguments, loop takes %lu", name, ufunc->nargs, types.size()); return false; } if (PyUFunc_RegisterLoopForType(ufunc, TypeDescriptor::Dtype(), fn, const_cast(types.data()), nullptr) < 0) { return false; } return true; } namespace ufuncs { template struct Add { T operator()(T a, T b) { return a + b; } }; template struct Subtract { T operator()(T a, T b) { return a - b; } }; template struct Multiply { T operator()(T a, T b) { return a * b; } }; template struct TrueDivide { T operator()(T a, T b) { return a / b; } }; static std::pair divmod_impl(float a, float b) { if (b == 0.0f) { float nan = std::numeric_limits::quiet_NaN(); float inf = std::numeric_limits::infinity(); if (std::isnan(a) || (a == 0.0f)) { return {nan, nan}; } else { return {std::signbit(a) == std::signbit(b) ? inf : -inf, nan}; } } float mod = std::fmod(a, b); float div = (a - mod) / b; if (mod != 0.0f) { if ((b < 0.0f) != (mod < 0.0f)) { mod += b; div -= 1.0f; } } else { mod = std::copysign(0.0f, b); } float floordiv; if (div != 0.0f) { floordiv = std::floor(div); if (div - floordiv > 0.5f) { floordiv += 1.0f; } } else { floordiv = std::copysign(0.0f, a / b); } return {floordiv, mod}; } template struct Divmod { std::pair operator()(T a, T b) { float c, d; std::tie(c, d) = divmod_impl(static_cast(a), static_cast(b)); return {T(c), T(d)}; } }; template struct FloorDivide { template ::is_integral, bool> = true> T operator()(T x, T y) { if (y == T(0)) { PyErr_WarnEx(PyExc_RuntimeWarning, "divide by zero encountered in floor_divide", 1); return T(0); } T v = x / y; if (((x > 0) != (y > 0)) && x % y != 0) { v = v - T(1); } return v; } template ::is_floating, bool> = true> T operator()(T a, T b) { return T(divmod_impl(static_cast(a), static_cast(b)).first); } }; template struct Remainder { template ::is_integral, bool> = true> T operator()(T x, T y) { if (y == 0) { 
PyErr_WarnEx(PyExc_RuntimeWarning, "divide by zero encountered in remainder", 1); return T(0); } T v = x % y; if (v != 0 && ((v < 0) != (y < 0))) { v = v + y; } return v; } template ::is_floating, bool> = true> T operator()(T a, T b) { return T(divmod_impl(static_cast(a), static_cast(b)).second); } }; template struct Fmod { T operator()(T a, T b) { return T(std::fmod(static_cast(a), static_cast(b))); } }; template struct Negative { T operator()(T a) { return -a; } }; template struct Positive { T operator()(T a) { return a; } }; template struct Power { T operator()(T a, T b) { return T(std::pow(static_cast(a), static_cast(b))); } }; template struct Abs { T operator()(T a) { return Eigen::numext::abs(a); } }; template struct Cbrt { T operator()(T a) { return T(std::cbrt(static_cast(a))); } }; template struct Ceil { T operator()(T a) { return T(std::ceil(static_cast(a))); } }; // Helper struct for getting a bit representation provided a byte size. template struct GetUnsignedInteger; template <> struct GetUnsignedInteger<1> { using type = uint8_t; }; template <> struct GetUnsignedInteger<2> { using type = uint16_t; }; template using BitsType = typename GetUnsignedInteger::type; template std::pair, BitsType> SignAndMagnitude(T x) { const BitsType x_bits = Eigen::numext::bit_cast>(x); // Unsigned floating point format (e.g. E8M0) => no sign bit (zero by // default). if constexpr (!std::numeric_limits::is_signed) { return {BitsType(0), x_bits}; } // For types that represent NaN by -0, (i.e. *fnuz), abs(x) remains -0 without // flipping the sign. Therefore, we need to explicitly check the // most-significant bit. // For types without NaNs (i.e. mxfloat), use xor to keep the sign bit, which // may be not the most-significant bit. constexpr BitsType kSignMask = BitsType(1) << (sizeof(BitsType) * CHAR_BIT - 1); constexpr bool has_nan = std::numeric_limits::has_quiet_NaN; const BitsType x_abs_bits = Eigen::numext::bit_cast>(Eigen::numext::abs(x)); return {has_nan ? 
x_bits & kSignMask : x_bits ^ x_abs_bits, x_abs_bits}; } template struct CopySign { T operator()(T a, T b) { // Unsigned floating point format => no change. if constexpr (!std::numeric_limits::is_signed) { return a; } auto [a_sign, a_abs_bits] = SignAndMagnitude(a); auto [b_sign, b_abs_bits] = SignAndMagnitude(b); BitsType rep = a_abs_bits | b_sign; return Eigen::numext::bit_cast(rep); } }; template struct Exp { T operator()(T a) { return T(std::exp(static_cast(a))); } }; template struct Exp2 { T operator()(T a) { return T(std::exp2(static_cast(a))); } }; template struct Expm1 { T operator()(T a) { return T(std::expm1(static_cast(a))); } }; template struct Floor { T operator()(T a) { return T(std::floor(static_cast(a))); } }; template struct Frexp { std::pair operator()(T a) { int exp; float f = std::frexp(static_cast(a), &exp); return {T(f), exp}; } }; template struct Heaviside { T operator()(T x, T h0) { if (Eigen::numext::isnan(x)) { return x; } auto [sign_x, abs_x] = SignAndMagnitude(x); // x == 0 if (abs_x == 0) { return h0; } return sign_x ? 
T(0.0f) : T(1.0f); } }; template struct Conjugate { T operator()(T a) { return a; } }; template struct IsFinite { bool operator()(T a) { return Eigen::numext::isfinite(a); } }; template struct IsInf { bool operator()(T a) { return Eigen::numext::isinf(a); } }; template struct IsNan { bool operator()(T a) { return Eigen::numext::isnan(a); } }; template struct Ldexp { T operator()(T a, int exp) { return T(std::ldexp(static_cast(a), exp)); } }; template struct Log { T operator()(T a) { return T(std::log(static_cast(a))); } }; template struct Log2 { T operator()(T a) { return T(std::log2(static_cast(a))); } }; template struct Log10 { T operator()(T a) { return T(std::log10(static_cast(a))); } }; template struct Log1p { T operator()(T a) { return T(std::log1p(static_cast(a))); } }; template struct LogAddExp { T operator()(T bx, T by) { float x = static_cast(bx); float y = static_cast(by); if (x == y) { // Handles infinities of the same sign. return T(x + std::log(2.0f)); } float out = std::numeric_limits::quiet_NaN(); if (x > y) { out = x + std::log1p(std::exp(y - x)); } else if (x < y) { out = y + std::log1p(std::exp(x - y)); } return T(out); } }; template struct LogAddExp2 { T operator()(T bx, T by) { float x = static_cast(bx); float y = static_cast(by); if (x == y) { // Handles infinities of the same sign. 
return T(x + 1.0f); } float out = std::numeric_limits::quiet_NaN(); if (x > y) { out = x + std::log1p(std::exp2(y - x)) / std::log(2.0f); } else if (x < y) { out = y + std::log1p(std::exp2(x - y)) / std::log(2.0f); } return T(out); } }; template struct Modf { std::pair operator()(T a) { float integral; float f = std::modf(static_cast(a), &integral); return {T(f), T(integral)}; } }; template struct Reciprocal { T operator()(T a) { return T(1.f / static_cast(a)); } }; template struct Rint { T operator()(T a) { return T(std::rint(static_cast(a))); } }; template struct Sign { T operator()(T a) { if (Eigen::numext::isnan(a)) { return a; } auto [sign_a, abs_a] = SignAndMagnitude(a); if (abs_a == 0) { return a; } return sign_a ? T(-1) : T(1); } }; template struct SignBit { bool operator()(T a) { auto [sign_a, abs_a] = SignAndMagnitude(a); return sign_a; } }; template struct Sqrt { T operator()(T a) { return T(std::sqrt(static_cast(a))); } }; template struct Square { T operator()(T a) { float f(a); return T(f * f); } }; template struct Trunc { T operator()(T a) { return T(std::trunc(static_cast(a))); } }; // Trigonometric functions template struct Sin { T operator()(T a) { return T(std::sin(static_cast(a))); } }; template struct Cos { T operator()(T a) { return T(std::cos(static_cast(a))); } }; template struct Tan { T operator()(T a) { return T(std::tan(static_cast(a))); } }; template struct Arcsin { T operator()(T a) { return T(std::asin(static_cast(a))); } }; template struct Arccos { T operator()(T a) { return T(std::acos(static_cast(a))); } }; template struct Arctan { T operator()(T a) { return T(std::atan(static_cast(a))); } }; template struct Arctan2 { T operator()(T a, T b) { return T(std::atan2(static_cast(a), static_cast(b))); } }; template struct Hypot { T operator()(T a, T b) { return T(std::hypot(static_cast(a), static_cast(b))); } }; template struct Sinh { T operator()(T a) { return T(std::sinh(static_cast(a))); } }; template struct Cosh { T operator()(T a) { 
return T(std::cosh(static_cast(a))); } }; template struct Tanh { T operator()(T a) { return T(std::tanh(static_cast(a))); } }; template struct Arcsinh { T operator()(T a) { return T(std::asinh(static_cast(a))); } }; template struct Arccosh { T operator()(T a) { return T(std::acosh(static_cast(a))); } }; template struct Arctanh { T operator()(T a) { return T(std::atanh(static_cast(a))); } }; template struct Deg2rad { T operator()(T a) { static constexpr float radians_per_degree = M_PI / 180.0f; return T(static_cast(a) * radians_per_degree); } }; template struct Rad2deg { T operator()(T a) { static constexpr float degrees_per_radian = 180.0f / M_PI; return T(static_cast(a) * degrees_per_radian); } }; template struct Eq { npy_bool operator()(T a, T b) { return a == b; } }; template struct Ne { npy_bool operator()(T a, T b) { return a != b; } }; template struct Lt { npy_bool operator()(T a, T b) { return a < b; } }; template struct Gt { npy_bool operator()(T a, T b) { return a > b; } }; template struct Le { npy_bool operator()(T a, T b) { return a <= b; } }; template struct Ge { npy_bool operator()(T a, T b) { return a >= b; } }; template struct Maximum { T operator()(T a, T b) { float fa(a), fb(b); return Eigen::numext::isnan(fa) || fa > fb ? a : b; } }; template struct Minimum { T operator()(T a, T b) { float fa(a), fb(b); return Eigen::numext::isnan(fa) || fa < fb ? a : b; } }; template struct Fmax { T operator()(T a, T b) { float fa(a), fb(b); return Eigen::numext::isnan(fb) || fa > fb ? a : b; } }; template struct Fmin { T operator()(T a, T b) { float fa(a), fb(b); return Eigen::numext::isnan(fb) || fa < fb ? 
a : b; } }; template struct LogicalNot { npy_bool operator()(T a) { return !static_cast(a); } }; template struct LogicalAnd { npy_bool operator()(T a, T b) { return static_cast(a) && static_cast(b); } }; template struct LogicalOr { npy_bool operator()(T a, T b) { return static_cast(a) || static_cast(b); } }; template struct LogicalXor { npy_bool operator()(T a, T b) { return static_cast(a) ^ static_cast(b); } }; template struct NextAfter { T operator()(T from, T to) { BitsType from_rep = Eigen::numext::bit_cast>(from); BitsType to_rep = Eigen::numext::bit_cast>(to); if (Eigen::numext::isnan(from) || Eigen::numext::isnan(to)) { return std::numeric_limits::quiet_NaN(); } if (from_rep == to_rep) { return to; } auto [from_sign, from_abs] = SignAndMagnitude(from); auto [to_sign, to_abs] = SignAndMagnitude(to); if (from_abs == 0) { if (to_abs == 0) { return to; } else { // Smallest subnormal signed like `to`. return Eigen::numext::bit_cast( static_cast>(0x01 | to_sign)); } } BitsType magnitude_adjustment = (from_abs > to_abs || from_sign != to_sign) ? static_cast>(-1) : static_cast>(1); BitsType out_int = from_rep + magnitude_adjustment; T out = Eigen::numext::bit_cast(out_int); // Some non-IEEE compatible formats may have a representation for NaN // instead of -0, ensure we return a zero in such cases. if constexpr (!std::numeric_limits::is_iec559) { if (Eigen::numext::isnan(out)) { return Eigen::numext::bit_cast(BitsType{0}); } } return out; } }; template struct Spacing { T operator()(T x) { CopySign copysign; if constexpr (!std::numeric_limits::has_infinity) { if (Eigen::numext::abs(x) == std::numeric_limits::max()) { if constexpr (!std::numeric_limits::has_quiet_NaN) return T(); return copysign(std::numeric_limits::quiet_NaN(), x); } } // Compute the distance between the input and the next number with greater // magnitude. The result should have the sign of the input. T away = std::numeric_limits::has_infinity ? 
std::numeric_limits::infinity() : std::numeric_limits::max(); away = copysign(away, x); return NextAfter()(x, away) - x; } }; } // namespace ufuncs } // namespace ml_dtypes #endif // ML_DTYPES_UFUNCS_H_ jax-ml-ml_dtypes-882eb0f/ml_dtypes/include/000077500000000000000000000000001510671665600207605ustar00rootroot00000000000000jax-ml-ml_dtypes-882eb0f/ml_dtypes/include/float8.h000066400000000000000000002046451510671665600223410ustar00rootroot00000000000000/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #ifndef ML_DTYPES_FLOAT8_H_ #define ML_DTYPES_FLOAT8_H_ // 8-bit Floating Point Interchange Format, as described by // https://arxiv.org/abs/2209.05433 // https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1 // https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf #include #include #include #include #include #include #include #include #ifdef __has_include #if __has_include() #include #endif #endif #if (defined(__cpp_lib_bitops) && __cpp_lib_bitops >= 201907L) #include #endif #include "Eigen/Core" namespace ml_dtypes { namespace float8_internal { // Forward-declarations of classes. 
class float8_e3m4; class float8_e4m3; class float8_e4m3fn; class float8_e4m3fnuz; class float8_e4m3b11fnuz; class float8_e5m2; class float8_e5m2fnuz; class float8_e8m0fnu; template class float8_base { protected: // Constructor tag to allow constexpr construction from bit representation. struct ConstructFromRepTag {}; constexpr float8_base(uint8_t rep, ConstructFromRepTag) : rep_{rep} {} public: static constexpr int kBits = 8; constexpr float8_base() : rep_(0) {} template explicit EIGEN_DEVICE_FUNC float8_base( T i, std::enable_if_t, int> = 0) : float8_base(ConvertFrom(static_cast(i)).rep(), ConstructFromRepTag{}) {} template explicit EIGEN_DEVICE_FUNC float8_base( T f, std::enable_if_t, int> = 0) : float8_base(ConvertFrom(f).rep(), ConstructFromRepTag{}) {} explicit EIGEN_DEVICE_FUNC float8_base(Eigen::bfloat16 bf16) : float8_base(ConvertFrom(bf16).rep(), ConstructFromRepTag{}) {} explicit EIGEN_DEVICE_FUNC float8_base(Eigen::half f16) : float8_base(ConvertFrom(f16).rep(), ConstructFromRepTag{}) {} constexpr uint8_t rep() const { return rep_; } template >> explicit EIGEN_DEVICE_FUNC operator T() const { return static_cast(static_cast(derived())); } explicit EIGEN_DEVICE_FUNC operator double() const { return ConvertTo(derived()); } EIGEN_DEVICE_FUNC operator float() const { return ConvertTo(derived()); } EIGEN_DEVICE_FUNC operator Eigen::bfloat16() const { return ConvertTo(derived()); } EIGEN_DEVICE_FUNC operator Eigen::half() const { return ConvertTo(derived()); } explicit EIGEN_DEVICE_FUNC operator bool() const { return (rep() & 0x7F) != 0; } constexpr Derived operator-() const { return Derived(static_cast(rep() ^ 0x80), ConstructFromRepTag{}); } constexpr const Derived& derived() const { return *static_cast(this); } constexpr Derived& derived() { return *static_cast(this); } static constexpr Derived FromRep(uint8_t rep) { return Derived(rep, ConstructFromRepTag{}); } // Conversions allowing saturation and truncation. 
template static inline EIGEN_DEVICE_FUNC Derived ConvertFrom(From from); template static inline EIGEN_DEVICE_FUNC To ConvertTo(Derived from); // Operators via float32. EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Derived operator+(const Derived& other) const { return Derived{float{derived()} + float{other}}; } EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Derived operator-(const Derived& other) const { return Derived{float{derived()} - float{other}}; } EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Derived operator*(const Derived& other) const { return Derived{float{derived()} * float{other}}; } EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Derived operator/(const Derived& other) const { return Derived{float{derived()} / float{other}}; } constexpr bool operator==(const Derived& other) const { return Compare(derived(), other) == Ordering::kEquivalent; } constexpr bool operator!=(const Derived& other) const { return Compare(derived(), other) != Ordering::kEquivalent; } EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<( const Derived& other) const { return Compare(derived(), other) == Ordering::kLess; } EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<=( const Derived& other) const { return Compare(derived(), other) <= Ordering::kEquivalent; } EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>( const Derived& other) const { return Compare(derived(), other) == Ordering::kGreater; } EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>=( const Derived& other) const { Ordering ordering = Compare(derived(), other); return ordering == Ordering::kGreater || ordering == Ordering::kEquivalent; } // Compound assignment. 
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Derived& operator+=( const Derived& other) { derived() = derived() + other; return derived(); } EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Derived& operator-=( const Derived& other) { derived() = derived() - other; return derived(); } EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Derived& operator*=( const Derived& other) { derived() = derived() * other; return derived(); } EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Derived& operator/=( const Derived& other) { derived() = derived() / other; return derived(); } private: static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC std::pair SignAndMagnitude(Derived x) { const uint8_t x_abs_bits = Eigen::numext::bit_cast(Eigen::numext::abs(x)); const uint8_t x_bits = Eigen::numext::bit_cast(x); const uint8_t x_sign = (x_bits ^ x_abs_bits) << (CHAR_BIT - Derived::kBits); return {x_sign, x_abs_bits}; } static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC int8_t SignAndMagnitudeToTwosComplement(uint8_t sign, uint8_t magnitude) { return magnitude ^ (static_cast(sign) < 0 ? 
-1 : 0); } enum Ordering : int8_t { kLess = -1, kEquivalent = 0, kGreater = 1, kUnordered = 2, }; EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC friend Ordering Compare( const Derived& lhs, const Derived& rhs) { if (Eigen::numext::isnan(lhs) || Eigen::numext::isnan(rhs)) { return Ordering::kUnordered; } auto [lhs_sign, lhs_mag] = SignAndMagnitude(lhs); auto [rhs_sign, rhs_mag] = SignAndMagnitude(rhs); if (lhs_mag == 0 && rhs_mag == 0) { return Ordering::kEquivalent; } int8_t lhs_twos_complement = SignAndMagnitudeToTwosComplement(lhs_sign, lhs_mag); int8_t rhs_twos_complement = SignAndMagnitudeToTwosComplement(rhs_sign, rhs_mag); if (lhs_twos_complement < rhs_twos_complement) { return Ordering::kLess; } if (lhs_twos_complement > rhs_twos_complement) { return Ordering::kGreater; } return Ordering::kEquivalent; } uint8_t rep_; }; template using RequiresIsDerivedFromFloat8Base = std::enable_if_t, T>, int>; class float8_e3m4 : public float8_base { // Exponent: 3, Mantissa: 4, bias: 3. // IEEE 754. private: using Base = float8_base; friend class float8_base; using Base::Base; public: template = 0> explicit EIGEN_DEVICE_FUNC float8_e3m4(T f8) : float8_e3m4(ConvertFrom(f8)) {} }; class float8_e4m3 : public float8_base { // Exponent: 4, Mantissa: 3, bias: 7. // IEEE 754. private: using Base = float8_base; friend class float8_base; using Base::Base; public: template = 0> explicit EIGEN_DEVICE_FUNC float8_e4m3(T f8) : float8_e4m3(ConvertFrom(f8)) {} }; class float8_e4m3fn : public float8_base { // Exponent: 4, Mantissa: 3, bias: 7. // Extended range: no inf, NaN represented by 0bS111'1111. // The "fn" suffix is for consistency with the corresponding LLVM/MLIR type, // signaling this type is not consistent with IEEE-754. The "f" indicates // it is finite values only. The "n" indicates it includes NaNs, but only // at the outer range. 
private: using Base = float8_base; friend class float8_base; using Base::Base; public: template = 0> explicit EIGEN_DEVICE_FUNC float8_e4m3fn(T f8) : float8_e4m3fn(ConvertFrom(f8)) {} }; class float8_e4m3b11fnuz : public float8_base { // Exponent: 4, Mantissa: 3, bias: 11. // Extended range: no inf, NaN represented by 0b1000'0000. private: using Base = float8_base; friend class float8_base; using Base::Base; public: template = 0> explicit EIGEN_DEVICE_FUNC float8_e4m3b11fnuz(T f8) : float8_e4m3b11fnuz(ConvertFrom(f8)) {} constexpr float8_e4m3b11fnuz operator-() const { if ((rep() & 0x7f) == 0x00) { return *this; } return Base::operator-(); } float8_e4m3b11fnuz operator-(const float8_e4m3b11fnuz& other) const { return Base::operator-(other); } explicit EIGEN_DEVICE_FUNC operator bool() const { return rep() != 0; } }; // Legacy name used in XLA (TODO(jewillco): remove). using float8_e4m3b11 = float8_e4m3b11fnuz; class float8_e4m3fnuz : public float8_base { // 8-bit floating point with 3 bit mantissa. // // An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits // mantissa. The suffix "fnuz" is consistent with LLVM/MLIR naming and is // derived from the differences to IEEE floating point conventions. `F` is // for "finite" (no infinities), `N` for with special NaN encoding, `UZ` for // unsigned zero. 
// // This type has the following characteristics: // * bit encoding: S1E4M3 - `0bSEEEEMMM` // * exponent bias: 8 // * infinities: Not supported // * NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits // set to all 0s - `0b10000000` // * denormals when exponent is 0 private: using Base = float8_base; friend class float8_base; using Base::Base; public: template = 0> explicit EIGEN_DEVICE_FUNC float8_e4m3fnuz(T f8) : float8_e4m3fnuz(ConvertFrom(f8)) {} constexpr float8_e4m3fnuz operator-() const { if ((rep() & 0x7f) == 0x00) { return *this; } return Base::operator-(); } float8_e4m3fnuz operator-(const float8_e4m3fnuz& other) const { return Base::operator-(other); } explicit EIGEN_DEVICE_FUNC operator bool() const { return rep() != 0; } }; class float8_e5m2 : public float8_base { // Exponent: 5, Mantissa: 2, bias: 15. // IEEE 754. private: using Base = float8_base; friend class float8_base; using Base::Base; public: template = 0> explicit EIGEN_DEVICE_FUNC float8_e5m2(T f8) : float8_e5m2(ConvertFrom(f8)) {} }; class float8_e5m2fnuz : public float8_base { // 8-bit floating point with 2 bit mantissa. // // An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits // mantissa. The suffix "fnuz" is consistent with LLVM/MLIR naming and is // derived from the differences to IEEE floating point conventions. `F` is // for "finite" (no infinities), `N` for with special NaN encoding, `UZ` for // unsigned zero. 
// // This type has the following characteristics: // * bit encoding: S1E5M2 - `0bSEEEEEMM` // * exponent bias: 16 // * infinities: Not supported // * NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits // set to all 0s - `0b10000000` // * denormals when exponent is 0 private: using Base = float8_base; friend class float8_base; using Base::Base; public: template = 0> explicit EIGEN_DEVICE_FUNC float8_e5m2fnuz(T f8) : float8_e5m2fnuz(ConvertFrom(f8)) {} constexpr float8_e5m2fnuz operator-() const { if ((rep() & 0x7f) == 0x00) { return *this; } return Base::operator-(); } float8_e5m2fnuz operator-(const float8_e5m2fnuz& other) const { return Base::operator-(other); } explicit EIGEN_DEVICE_FUNC operator bool() const { return rep() != 0; } }; class float8_e8m0fnu : public float8_base { // 8-bit floating point with 8 bit exponent, no sign and zero mantissa. // // See: // https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf // // An 8-bit floating point type with no sign bit, 8 bits exponent and 0 bits // mantissa. The suffix "fnuz" is consistent with LLVM/MLIR naming and is // derived from the differences to IEEE floating point conventions. `F` is // for "finite" (no infinities), `N` for with special NaN encoding, `U` for // unsigned. 
// // This type has the following characteristics: // * bit encoding: S0E8M0 - `0bEEEEEEEE` // * exponent bias: 127 // * infinities: Not supported // * NaNs: Supported with exponent bits set to 1s - `0b11111111` private: using Base = float8_base; friend class float8_base; using Base::Base; public: template = 0> explicit EIGEN_DEVICE_FUNC float8_e8m0fnu(T f8) : float8_e8m0fnu(ConvertFrom(f8)) {} constexpr float8_e8m0fnu operator-() const { // No negative numbers supported in E8M0 => NaN return float8_e8m0fnu::FromRep(0xFF); } float8_e8m0fnu operator-(const float8_e8m0fnu& other) const { return Base::operator-(other); } explicit EIGEN_DEVICE_FUNC operator bool() const { // No zero supported in E8M0 format. return true; } // Comparison simplified to uint8_t compare. EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<( const float8_e8m0fnu& other) const { if (Eigen::numext::isnan(*this) || Eigen::numext::isnan(other)) { return false; } return rep() < other.rep(); } EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<=( const float8_e8m0fnu& other) const { if (Eigen::numext::isnan(*this) || Eigen::numext::isnan(other)) { return false; } return rep() <= other.rep(); } EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>( const float8_e8m0fnu& other) const { if (Eigen::numext::isnan(*this) || Eigen::numext::isnan(other)) { return false; } return rep() > other.rep(); } EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>=( const float8_e8m0fnu& other) const { if (Eigen::numext::isnan(*this) || Eigen::numext::isnan(other)) { return false; } return rep() >= other.rep(); } }; constexpr double ConstexprAbs(double x) { return x < 0.0 ? -x : x; } constexpr double ConstexprCeil(double x) { constexpr double kIntegerThreshold = uint64_t{1} << (std::numeric_limits::digits - 1); // Too big or NaN inputs get returned unchanged. if (!(ConstexprAbs(x) < kIntegerThreshold)) { return x; } const double x_trunc = static_cast(static_cast(x)); return x_trunc < x ? 
x_trunc + 1.0 : x_trunc;
}

// Compile-time floor, expressed through ConstexprCeil via the identity
// floor(x) == -ceil(-x).
constexpr double ConstexprFloor(double x) { return -ConstexprCeil(-x); }

// log10(2): converts a count of binary digits (or a binary exponent) into the
// matching decimal quantity in the helpers below.
constexpr double kLog10Of2 = 0.3010299956639812;

// C17 5.2.4.2.2p11:
// "number of decimal digits, q, such that any floating-point number with q
// decimal digits can be rounded into a floating-point number with p radix b
// digits and back again without change to the q decimal digits"
// floor((p - 1) * log10(2));
// Used for numeric_limits::digits10 below; `digits` is p.
constexpr int Digits10FromDigits(int digits) {
  return static_cast(ConstexprFloor((digits - 1) * kLog10Of2));
}

// C17 5.2.4.2.2p11:
// "number of decimal digits, n, such that any floating-point number with p
// radix b digits can be rounded to a floating-point number with n decimal
// digits and back again without change to the value"
// ceil(1 + p * log10(2));
// Used for numeric_limits::max_digits10 below; `digits` is p.
constexpr int MaxDigits10FromDigits(int digits) {
  return static_cast(ConstexprCeil(1.0 + (digits * kLog10Of2)));
}

// C17 5.2.4.2.2p11:
// "minimum negative integer such that 10 raised to that power is in the range
// of normalized floating-point numbers"
// ceil(log10(2**(emin - 1))) == ceil((emin - 1) * log10(2));
// Used for numeric_limits::min_exponent10 below; `min_exponent` is emin, so
// the smallest normal value is 2**(emin - 1).
constexpr int MinExponent10FromMinExponent(int min_exponent) {
  return static_cast(ConstexprCeil((min_exponent - 1) * kLog10Of2));
}

// C17 5.2.4.2.2p11:
// "maximum integer such that 10 raised to that power is in the range of
// representable finite floating-point numbers"
// floor(log10((1 - 2**-p) * 2**emax)) == floor(log10(1 - 2**-p) +
// emax * log10(2))
// Used for numeric_limits::max_exponent10 below; `digits` is p and
// `max_exponent` is emax.
constexpr int MaxExponent10FromMaxExponentAndDigits(int max_exponent,
                                                    int digits) {
  // We only support digits in {1,2,3,4,5}. This table would grow if we wanted
  // to handle more values. It is indexed by `digits - 1` with no bounds check,
  // so callers must keep `digits` within that range.
  constexpr double kLog10OfOnePredecessor[] = {
      // log10(1 - 2**-1)
      -0.3010299956639812,
      // log10(1 - 2**-2)
      -0.12493873660829993,
      // log10(1 - 2**-3)
      -0.057991946977686754,
      // log10(1 - 2**-4)
      -0.028028723600243537,
      // log10(1 - 2**-5)
      -0.013788284485633295,
  };
  return static_cast(ConstexprFloor(kLog10OfOnePredecessor[digits - 1] +
                                    max_exponent * kLog10Of2));
}

// Structures for use in specializing std::numeric_limits.
//
// numeric_limits_float8_base holds the members shared by every float8 format.
// Each per-format struct below inherits it and adds the format-specific
// members; numeric_limits_float8_e8m0fnu additionally overrides is_signed and
// has_denorm.
//
// In the per-format structs, min_exponent and max_exponent follow the
// std::numeric_limits convention: each is one more than the binary exponent e
// of min() / max() written as 1.m * 2^e.
struct numeric_limits_float8_base {
  // NOLINTBEGIN: these names must match std::numeric_limits.
  static inline constexpr const bool is_specialized = true;
  static inline constexpr const bool is_signed = true;
  static inline constexpr const bool is_integer = false;
  static inline constexpr const bool is_exact = false;
  static inline constexpr const bool has_quiet_NaN = true;
  // has_denorm and has_denorm_loss are deprecated in C++23.
#if !defined(__cplusplus) || __cplusplus < 202302L
  static inline constexpr const std::float_denorm_style has_denorm =
      std::denorm_present;
  static inline constexpr const bool has_denorm_loss = false;
#endif
  static inline constexpr const std::float_round_style round_style =
      std::round_to_nearest;
  static inline constexpr const bool is_bounded = true;
  static inline constexpr const bool is_modulo = false;
  // radix, traps and tinyness_before are forwarded from an existing
  // std::numeric_limits specialization rather than restated here.
  static inline constexpr const int radix = std::numeric_limits::radix;
  static inline constexpr const bool traps = std::numeric_limits::traps;
  static inline constexpr const bool tinyness_before =
      std::numeric_limits::tinyness_before;
  // NOLINTEND
};

// E3M4: 1 sign bit | 3 exponent bits (bias 3) | 4 mantissa bits.
// IEEE-style: exponent 0b111 encodes infinities (zero mantissa) and NaNs.
struct numeric_limits_float8_e3m4 : public numeric_limits_float8_base {
 private:
  static inline constexpr const int kExponentBias = 3;
  static inline constexpr const int kMantissaBits = 4;

 public:
  // NOLINTBEGIN: these names must match std::numeric_limits.
  static inline constexpr const int digits = kMantissaBits + 1;
  static inline constexpr const int digits10 = Digits10FromDigits(digits);
  static inline constexpr const int max_digits10 =
      MaxDigits10FromDigits(digits);
  static inline constexpr const int min_exponent = (1 - kExponentBias) + 1;
  static inline constexpr const int min_exponent10 =
      MinExponent10FromMinExponent(min_exponent);
  static inline constexpr const int max_exponent = 0b111 - kExponentBias;
  static inline constexpr const int max_exponent10 =
      MaxExponent10FromMaxExponentAndDigits(max_exponent, digits);
  static inline constexpr const bool is_iec559 = true;
  static inline constexpr const bool has_infinity = true;
  static inline constexpr const bool has_signaling_NaN = true;
  // NOLINTEND

  // 1.0 * 2^(0b001 - 3) = 1.0 * 2^-2 = 1/4 (min normal)
  static constexpr float8_e3m4 min() {
    return float8_e3m4::FromRep(1 << kMantissaBits);
  }
  // -(1 + 0b1111 * 2^-4) * 2^(0b110 - 3) = -(1 + 15/16) * 2^3 = -15.5
  static constexpr float8_e3m4 lowest() {
    return float8_e3m4::FromRep(0b1'110'1111);
  }
  // (1 + 0b1111 * 2^-4) * 2^(0b110 - 3) = (1 + 15/16) * 2^3 = 15.5
  static constexpr float8_e3m4 max() {
    return float8_e3m4::FromRep(0b0'110'1111);
  }
  // (1 + 1/16) * 2^0 - 1.0 = 1.0 + 1/16 - 1.0 = 1/16
  // Encoded as the denormal 0b0'000'0100: (4/16) * 2^-2 = 1/16.
  static constexpr float8_e3m4 epsilon() {
    return float8_e3m4::FromRep(0b0'000'0100);
  }
  // 1.0 * 2^-1 = 0.5
  static constexpr float8_e3m4 round_error() {
    return float8_e3m4::FromRep((-1 + kExponentBias) << kMantissaBits);
  }
  // All-ones exponent, zero mantissa.
  static constexpr float8_e3m4 infinity() {
    return float8_e3m4::FromRep(0b0'111'0000);
  }
  static constexpr float8_e3m4 quiet_NaN() {
    // IEEE 754-2019 6.2.1: "All binary NaN bit strings have the sign bit S set
    // to 0 or 1 and all the bits of the biased exponent field E set to 1
    // (see 3.4). A quiet NaN bit string should be encoded with the first bit
    // (d1) of the trailing significand field T being 1."
    return float8_e3m4::FromRep(0b0'111'1000);
  }
  static constexpr float8_e3m4 signaling_NaN() {
    // IEEE 754-2019 6.2.1: "A signaling NaN bit string should be encoded with
    // the first bit of the trailing significand field being 0."
    return float8_e3m4::FromRep(0b0'111'0100);
  }
  // 2^(-2) * 2^(-4) = 2^-6 = 1/64 (min denormal)
  static constexpr float8_e3m4 denorm_min() {
    return float8_e3m4::FromRep(0b0'000'0001);
  }
};

// E4M3: 1 sign bit | 4 exponent bits (bias 7) | 3 mantissa bits.
// IEEE-style: exponent 0b1111 encodes infinities (zero mantissa) and NaNs.
struct numeric_limits_float8_e4m3 : public numeric_limits_float8_base {
 private:
  static inline constexpr const int kExponentBias = 7;
  static inline constexpr const int kMantissaBits = 3;

 public:
  // NOLINTBEGIN: these names must match std::numeric_limits.
  static inline constexpr const int digits = kMantissaBits + 1;
  static inline constexpr const int digits10 = Digits10FromDigits(digits);
  static inline constexpr const int max_digits10 =
      MaxDigits10FromDigits(digits);
  static inline constexpr const int min_exponent = (1 - kExponentBias) + 1;
  static inline constexpr const int min_exponent10 =
      MinExponent10FromMinExponent(min_exponent);
  static inline constexpr const int max_exponent = 0b1111 - kExponentBias;
  static inline constexpr const int max_exponent10 =
      MaxExponent10FromMaxExponentAndDigits(max_exponent, digits);
  static inline constexpr const bool is_iec559 = true;
  static inline constexpr const bool has_infinity = true;
  static inline constexpr const bool has_signaling_NaN = true;
  // NOLINTEND

  // 1.0 * 2^(0b0001 - 7) = 1.0 * 2^-6 = 1/64 (min normal)
  static constexpr float8_e4m3 min() {
    return float8_e4m3::FromRep(1 << kMantissaBits);
  }
  // -(1 + 0b111 * 2^-3) * 2^(0b1110 - 7) = -(1 + 7/8) * 2^7 = -240
  static constexpr float8_e4m3 lowest() {
    return float8_e4m3::FromRep(0b1'1110'111);
  }
  // (1 + 0b111 * 2^-3) * 2^(0b1110 - 7) = (1 + 7/8) * 2^7 = 240
  static constexpr float8_e4m3 max() {
    return float8_e4m3::FromRep(0b0'1110'111);
  }
  // 1.0 * 2^-3 = 0.125
  static constexpr float8_e4m3 epsilon() {
    return float8_e4m3::FromRep((-kMantissaBits + kExponentBias)
                                << kMantissaBits);
  }
  // 1.0 * 2^-1 = 0.5
  static constexpr float8_e4m3 round_error() {
    return float8_e4m3::FromRep((-1 + kExponentBias) << kMantissaBits);
  }
  // All-ones exponent, zero mantissa.
  static constexpr float8_e4m3 infinity() {
    return float8_e4m3::FromRep(0b0'1111'000);
  }
  static constexpr float8_e4m3 quiet_NaN() {
    // IEEE 754-2019 6.2.1: "All binary NaN bit strings have the sign bit S set
    // to 0 or 1 and all the bits of the biased exponent field E set to 1
    // (see 3.4). A quiet NaN bit string should be encoded with the first bit
    // (d1) of the trailing significand field T being 1."
    return float8_e4m3::FromRep(0b0'1111'100);
  }
  static constexpr float8_e4m3 signaling_NaN() {
    // IEEE 754-2019 6.2.1: "A signaling NaN bit string should be encoded with
    // the first bit of the trailing significand field being 0."
    return float8_e4m3::FromRep(0b0'1111'001);
  }
  // 2^(-6) * 2^(-3) = 2^-9 = 1/512 (min denormal)
  static constexpr float8_e4m3 denorm_min() {
    return float8_e4m3::FromRep(0b0'0000'001);
  }
};

// E4M3FN: 1 sign bit | 4 exponent bits (bias 7) | 3 mantissa bits.
// No infinities. Only the all-ones magnitude 0b1111'111 is NaN, so exponent
// 0b1111 with any other mantissa is a finite value.
struct numeric_limits_float8_e4m3fn : public numeric_limits_float8_base {
 private:
  static inline constexpr const int kExponentBias = 7;
  static inline constexpr const int kMantissaBits = 3;

 public:
  // NOLINTBEGIN: these names must match std::numeric_limits.
  static inline constexpr const int digits = kMantissaBits + 1;
  static inline constexpr const int digits10 = Digits10FromDigits(digits);
  static inline constexpr const int max_digits10 =
      MaxDigits10FromDigits(digits);
  static inline constexpr const int min_exponent = (1 - kExponentBias) + 1;
  static inline constexpr const int min_exponent10 =
      MinExponent10FromMinExponent(min_exponent);
  // Extended format: the all-ones exponent 0b1111 holds finite values, hence
  // the extra + 1 compared with the IEEE-style formats.
  static inline constexpr const int max_exponent =
      (0b1111 - kExponentBias) + 1;  // Extended format.
  static inline constexpr const int max_exponent10 =
      MaxExponent10FromMaxExponentAndDigits(max_exponent, digits);
  static inline constexpr const bool is_iec559 = false;
  static inline constexpr const bool has_infinity = false;
  static inline constexpr const bool has_signaling_NaN = false;
  // NOLINTEND

  // 1.0 * 2^(0b0001 - 7) = 1.0 * 2^-6 = 0.015625
  static constexpr float8_e4m3fn min() {
    return float8_e4m3fn::FromRep(0b0'0001 << kMantissaBits);
  }
  // -(1 + 0b110 * 2^-3) * 2^(0b1111 - 7) = -1.75 * 2^8 = -448
  static constexpr float8_e4m3fn lowest() {
    return float8_e4m3fn::FromRep(0b1'1111'110);
  }
  // (1 + 0b110 * 2^-3) * 2^(0b1111 - 7) = 1.75 * 2^8 = 448
  static constexpr float8_e4m3fn max() {
    return float8_e4m3fn::FromRep(0b0'1111'110);
  }
  // 1.0 * 2^-3 = 0.125
  static constexpr float8_e4m3fn epsilon() {
    return float8_e4m3fn::FromRep((-kMantissaBits + kExponentBias)
                                  << kMantissaBits);
  }
  // 1.0 * 2^-1 = 0.5
  static constexpr float8_e4m3fn round_error() {
    return float8_e4m3fn::FromRep((-1 + kExponentBias) << kMantissaBits);
  }
  // No infinity encoding (has_infinity == false): returns the NaN pattern.
  static constexpr float8_e4m3fn infinity() {
    return float8_e4m3fn::FromRep(0b0'1111'111);
  }
  // NaN.
  static constexpr float8_e4m3fn quiet_NaN() {
    return float8_e4m3fn::FromRep(0b0'1111'111);
  }
  static constexpr float8_e4m3fn signaling_NaN() {
    return float8_e4m3fn::FromRep(0b0'1111'111);
  }
  // 1.0 * 2^(-7 - 3 + 1) = 1.0 * 2^-9 = 0.001953125
  static constexpr float8_e4m3fn denorm_min() {
    return float8_e4m3fn::FromRep(0b0'0000'001);
  }
};

// E4M3B11FNUZ: 1 sign bit | 4 exponent bits (bias 11) | 3 mantissa bits.
// No infinities and no negative zero. The pattern 0b1'0000'000 (what would
// be -0) is the only NaN.
struct numeric_limits_float8_e4m3b11fnuz : public numeric_limits_float8_base {
 private:
  static inline constexpr const int kExponentBias = 11;
  static inline constexpr const int kMantissaBits = 3;

 public:
  // NOLINTBEGIN: these names must match std::numeric_limits.
  static inline constexpr const int digits = kMantissaBits + 1;
  static inline constexpr const int digits10 = Digits10FromDigits(digits);
  static inline constexpr const int max_digits10 =
      MaxDigits10FromDigits(digits);
  static inline constexpr const int min_exponent = (1 - kExponentBias) + 1;
  static inline constexpr const int min_exponent10 =
      MinExponent10FromMinExponent(min_exponent);
  static inline constexpr const int max_exponent =
      (0b1111 - kExponentBias) + 1;  // Extended format.
  static inline constexpr const int max_exponent10 =
      MaxExponent10FromMaxExponentAndDigits(max_exponent, digits);
  static inline constexpr const bool is_iec559 = false;
  static inline constexpr const bool has_infinity = false;
  static inline constexpr const bool has_signaling_NaN = false;
  // NOLINTEND

  // 1.0 * 2^(0b0001 - 11) = 1.0 * 2^-10 = 0.0009765625
  static constexpr float8_e4m3b11fnuz min() {
    return float8_e4m3b11fnuz::FromRep(1 << kMantissaBits);
  }
  // -(1 + 0b111 * 2^-3) * 2^(0b1111 - 11) = -1.875 * 2^4 = -30
  static constexpr float8_e4m3b11fnuz lowest() {
    return float8_e4m3b11fnuz::FromRep(0b1'1111'111);
  }
  // (1 + 0b111 * 2^-3) * 2^(0b1111 - 11) = 1.875 * 2^4 = 30
  static constexpr float8_e4m3b11fnuz max() {
    return float8_e4m3b11fnuz::FromRep(0b0'1111'111);
  }
  // 1.0 * 2^-3 = 0.125
  static constexpr float8_e4m3b11fnuz epsilon() {
    return float8_e4m3b11fnuz::FromRep((-kMantissaBits + kExponentBias)
                                       << kMantissaBits);
  }
  // 1.0 * 2^-1 = 0.5
  static constexpr float8_e4m3b11fnuz round_error() {
    return float8_e4m3b11fnuz::FromRep((-1 + kExponentBias) << kMantissaBits);
  }
  // No infinity encoding (has_infinity == false): returns the NaN pattern.
  static constexpr float8_e4m3b11fnuz infinity() {
    return float8_e4m3b11fnuz::FromRep(0b1'0000'000);
  }
  // NaN.
  static constexpr float8_e4m3b11fnuz quiet_NaN() {
    return float8_e4m3b11fnuz::FromRep(0b1'0000'000);
  }
  static constexpr float8_e4m3b11fnuz signaling_NaN() {
    return float8_e4m3b11fnuz::FromRep(0b1'0000'000);
  }
  // 1.0 * 2^(-11 - 3 + 1) = 1.0 * 2^-13 = 0.0001220703125
  static constexpr float8_e4m3b11fnuz denorm_min() {
    return float8_e4m3b11fnuz::FromRep(0b0'0000'001);
  }
};

// E4M3FNUZ: 1 sign bit | 4 exponent bits (bias 8) | 3 mantissa bits.
// No infinities and no negative zero. 0x80 (what would be -0) is the only
// NaN.
struct numeric_limits_float8_e4m3fnuz : public numeric_limits_float8_base {
 private:
  static inline constexpr const int kExponentBias = 8;
  static inline constexpr const int kMantissaBits = 3;

 public:
  // NOLINTBEGIN: these names must match std::numeric_limits.
  static inline constexpr const int digits = kMantissaBits + 1;
  static inline constexpr const int digits10 = Digits10FromDigits(digits);
  static inline constexpr const int max_digits10 =
      MaxDigits10FromDigits(digits);
  static inline constexpr const int min_exponent = (1 - kExponentBias) + 1;
  static inline constexpr const int min_exponent10 =
      MinExponent10FromMinExponent(min_exponent);
  static inline constexpr const int max_exponent =
      (0b1111 - kExponentBias) + 1;  // Extended format.
  static inline constexpr const int max_exponent10 =
      MaxExponent10FromMaxExponentAndDigits(max_exponent, digits);
  static inline constexpr const bool is_iec559 = false;
  static inline constexpr const bool has_infinity = false;
  static inline constexpr const bool has_signaling_NaN = false;
  // NOLINTEND

  // 0x08 = 0b0'0001'000: 1.0 * 2^(1 - 8) = 2^-7 (min normal)
  static constexpr float8_e4m3fnuz min() {
    return float8_e4m3fnuz::FromRep(0x08);
  }
  // 0xFF = 0b1'1111'111: -(1 + 7/8) * 2^(15 - 8) = -240
  static constexpr float8_e4m3fnuz lowest() {
    return float8_e4m3fnuz::FromRep(0xFF);
  }
  // 0x7F = 0b0'1111'111: (1 + 7/8) * 2^(15 - 8) = 240
  static constexpr float8_e4m3fnuz max() {
    return float8_e4m3fnuz::FromRep(0x7F);
  }
  // 1.0 * 2^-3 = 0.125
  static constexpr float8_e4m3fnuz epsilon() {
    return float8_e4m3fnuz::FromRep((-kMantissaBits + kExponentBias)
                                    << kMantissaBits);
  }
  // 1.0 * 2^-1 = 0.5
  static constexpr float8_e4m3fnuz round_error() {
    return float8_e4m3fnuz::FromRep((-1 + kExponentBias) << kMantissaBits);
  }
  // No infinity encoding (has_infinity == false): returns the NaN pattern.
  static constexpr float8_e4m3fnuz infinity() {
    return float8_e4m3fnuz::FromRep(0x80);
  }
  // NaN.
  static constexpr float8_e4m3fnuz quiet_NaN() {
    return float8_e4m3fnuz::FromRep(0x80);
  }
  static constexpr float8_e4m3fnuz signaling_NaN() {
    return float8_e4m3fnuz::FromRep(0x80);
  }
  // 0x01: 2^(1 - 8) * 2^-3 = 2^-10 (min denormal)
  static constexpr float8_e4m3fnuz denorm_min() {
    return float8_e4m3fnuz::FromRep(0x01);
  }
};

// E5M2: 1 sign bit | 5 exponent bits (bias 15) | 2 mantissa bits.
// IEEE-style: exponent 0b11111 encodes infinities (zero mantissa) and NaNs.
struct numeric_limits_float8_e5m2 : public numeric_limits_float8_base {
 private:
  static inline constexpr const int kExponentBias = 15;
  static inline constexpr const int kMantissaBits = 2;

 public:
  // NOLINTBEGIN: these names must match std::numeric_limits.
  static inline constexpr const int digits = kMantissaBits + 1;
  static inline constexpr const int digits10 = Digits10FromDigits(digits);
  static inline constexpr const int max_digits10 =
      MaxDigits10FromDigits(digits);
  static inline constexpr const int min_exponent = (1 - kExponentBias) + 1;
  static inline constexpr const int min_exponent10 =
      MinExponent10FromMinExponent(min_exponent);
  static inline constexpr const int max_exponent = 0b11111 - kExponentBias;
  static inline constexpr const int max_exponent10 =
      MaxExponent10FromMaxExponentAndDigits(max_exponent, digits);
  static inline constexpr const bool is_iec559 = true;
  static inline constexpr const bool has_infinity = true;
  static inline constexpr const bool has_signaling_NaN = true;
  // NOLINTEND

  // 1.0 * 2^(0b00001 - 15) = 1.0 * 2^-14 = 0.00006103515625
  static constexpr float8_e5m2 min() {
    return float8_e5m2::FromRep(1 << kMantissaBits);
  }
  // -(1 + 0b11 * 2^-2) * 2^(0b11110 - 15) = -1.75 * 2^15 = -57344
  static constexpr float8_e5m2 lowest() {
    return float8_e5m2::FromRep(0b1'11110'11);
  }
  // (1 + 0b11 * 2^-2) * 2^(0b11110 - 15) = 1.75 * 2^15 = 57344
  static constexpr float8_e5m2 max() {
    return float8_e5m2::FromRep(0b0'11110'11);
  }
  // 1.0 * 2^-2 = 0.25
  static constexpr float8_e5m2 epsilon() {
    return float8_e5m2::FromRep((-kMantissaBits + kExponentBias)
                                << kMantissaBits);
  }
  // 1.0 * 2^-1 = 0.5
  static constexpr float8_e5m2 round_error() {
    return float8_e5m2::FromRep((-1 + kExponentBias) << kMantissaBits);
  }
  // All-ones exponent, zero mantissa.
  static constexpr float8_e5m2 infinity() {
    return float8_e5m2::FromRep(0b0'11111'00);
  }
  static constexpr float8_e5m2 quiet_NaN() {
    // IEEE 754-2019 6.2.1: "All binary NaN bit strings have the sign bit S set
    // to 0 or 1 and all the bits of the biased exponent field E set to 1
    // (see 3.4). A quiet NaN bit string should be encoded with the first bit
    // (d1) of the trailing significand field T being 1."
    return float8_e5m2::FromRep(0b0'11111'10);
  }
  static constexpr float8_e5m2 signaling_NaN() {
    // IEEE 754-2019 6.2.1: "A signaling NaN bit string should be encoded with
    // the first bit of the trailing significand field being 0."
    return float8_e5m2::FromRep(0b0'11111'01);
  }
  // 1.0 * 2^(-15 - 2 + 1) = 1.0 * 2^-16 = 0.0000152587890625
  static constexpr float8_e5m2 denorm_min() {
    return float8_e5m2::FromRep(0b0'00000'01);
  }
};

// E5M2FNUZ: 1 sign bit | 5 exponent bits (bias 16) | 2 mantissa bits.
// No infinities and no negative zero. 0x80 (what would be -0) is the only
// NaN.
struct numeric_limits_float8_e5m2fnuz : public numeric_limits_float8_base {
 private:
  static inline constexpr const int kExponentBias = 16;
  static inline constexpr const int kMantissaBits = 2;

 public:
  // NOLINTBEGIN: these names must match std::numeric_limits.
  static inline constexpr const int digits = kMantissaBits + 1;
  static inline constexpr const int digits10 = Digits10FromDigits(digits);
  static inline constexpr const int max_digits10 =
      MaxDigits10FromDigits(digits);
  static inline constexpr const int min_exponent = (1 - kExponentBias) + 1;
  static inline constexpr const int min_exponent10 =
      MinExponent10FromMinExponent(min_exponent);
  // Extended format: the all-ones exponent 0b11111 holds finite values.
  static inline constexpr const int max_exponent =
      (0b11111 - kExponentBias) + 1;
  static inline constexpr const int max_exponent10 =
      MaxExponent10FromMaxExponentAndDigits(max_exponent, digits);
  static inline constexpr const bool is_iec559 = false;
  static inline constexpr const bool has_infinity = false;
  static inline constexpr const bool has_signaling_NaN = false;
  // NOLINTEND

  // 0x04 = 0b0'00001'00: 1.0 * 2^(1 - 16) = 2^-15 (min normal)
  static constexpr float8_e5m2fnuz min() {
    return float8_e5m2fnuz::FromRep(0x04);
  }
  // 0xFF = 0b1'11111'11: -1.75 * 2^(31 - 16) = -57344
  static constexpr float8_e5m2fnuz lowest() {
    return float8_e5m2fnuz::FromRep(0xFF);
  }
  // 0x7F = 0b0'11111'11: 1.75 * 2^(31 - 16) = 57344
  static constexpr float8_e5m2fnuz max() {
    return float8_e5m2fnuz::FromRep(0x7F);
  }
  // 1.0 * 2^-2 = 0.25
  static constexpr float8_e5m2fnuz epsilon() {
    return float8_e5m2fnuz::FromRep((-kMantissaBits + kExponentBias)
                                    << kMantissaBits);
  }
  // 1.0 * 2^-1 = 0.5
  static constexpr float8_e5m2fnuz round_error() {
    return float8_e5m2fnuz::FromRep((-1 + kExponentBias) << kMantissaBits);
  }
  // No infinity encoding (has_infinity == false): returns the NaN pattern.
  static constexpr float8_e5m2fnuz infinity() {
    return float8_e5m2fnuz::FromRep(0x80);
  }
  // NaN.
  static constexpr float8_e5m2fnuz quiet_NaN() {
    return float8_e5m2fnuz::FromRep(0x80);
  }
  static constexpr float8_e5m2fnuz signaling_NaN() {
    return float8_e5m2fnuz::FromRep(0x80);
  }
  // 0x01: 2^(1 - 16) * 2^-2 = 2^-17 (min denormal)
  static constexpr float8_e5m2fnuz denorm_min() {
    return float8_e5m2fnuz::FromRep(0x01);
  }
};

// E8M0FNU: unsigned, 8 exponent bits (bias 127), no mantissa bits, so every
// finite value is the power of two 2^(rep - 127). There is no zero,
// infinity or denormal; 0xFF is NaN.
struct numeric_limits_float8_e8m0fnu : public numeric_limits_float8_base {
 private:
  static inline constexpr const int kExponentBias = 127;
  static inline constexpr const int kMantissaBits = 0;

 public:
  // NOLINTBEGIN: these names must match std::numeric_limits.
  static inline constexpr const bool is_signed = false;
  // has_denorm and has_denorm_loss are deprecated in C++23.
#if !defined(__cplusplus) || __cplusplus < 202302L
  static inline constexpr const std::float_denorm_style has_denorm =
      std::denorm_absent;
#endif
  static inline constexpr const int digits = kMantissaBits + 1;
  static inline constexpr const int digits10 = Digits10FromDigits(digits);
  static inline constexpr const int max_digits10 =
      MaxDigits10FromDigits(digits);
  // 2^-127 (rep 0x00) is the smallest value; per the std convention
  // min_exponent is one more than its exponent.
  static inline constexpr const int min_exponent = -kExponentBias + 1;
  static inline constexpr const int min_exponent10 =
      MinExponent10FromMinExponent(min_exponent);
  // The largest finite value is 2^127 (rep 0xfe), since the all-ones
  // encoding 0xFF is reserved for NaN; max_exponent is one more than 127.
  static inline constexpr const int max_exponent = kExponentBias + 1;
  static inline constexpr const int max_exponent10 =
      MaxExponent10FromMaxExponentAndDigits(max_exponent, digits);
  static inline constexpr const bool is_iec559 = false;
  static inline constexpr const bool has_infinity = false;
  static inline constexpr const bool has_signaling_NaN = false;
  // NOLINTEND

  // 2^(0 - 127) = 2^-127
  static constexpr float8_e8m0fnu min() {
    return float8_e8m0fnu::FromRep(0x00);
  }
  // Unsigned: the lowest value is the same as min().
  static constexpr float8_e8m0fnu lowest() {
    return float8_e8m0fnu::FromRep(0x00);
  }
  // 2^(254 - 127) = 2^127
  static constexpr float8_e8m0fnu max() {
    return float8_e8m0fnu::FromRep(0xfe);
  }
  // rep 127 = 2^0 = 1.0 (the next value after 1.0 is 2.0).
  static constexpr float8_e8m0fnu epsilon() {
    return float8_e8m0fnu::FromRep((-kMantissaBits + kExponentBias)
                                   << kMantissaBits);
  }
  // rep 126 = 2^-1 = 0.5
  static constexpr float8_e8m0fnu round_error() {
    return float8_e8m0fnu::FromRep((-1 + kExponentBias) << kMantissaBits);
  }
  // No infinity encoding (has_infinity == false): returns the NaN pattern.
  static constexpr float8_e8m0fnu infinity() {
    return float8_e8m0fnu::FromRep(0xFF);
  }
  // NaN.
  static constexpr float8_e8m0fnu quiet_NaN() {
    return float8_e8m0fnu::FromRep(0xFF);
  }
  static constexpr float8_e8m0fnu signaling_NaN() {
    return float8_e8m0fnu::FromRep(0xFF);
  }
  static constexpr float8_e8m0fnu denorm_min() {
    // No denorm => smallest value.
    return float8_e8m0fnu::FromRep(0x00);
  }
};

}  // namespace float8_internal
}  // namespace ml_dtypes

namespace std {
// Standard-library overrides. Note that these are picked up by Eigen as well.
// Each float8 type's std::numeric_limits specialization simply inherits the
// matching ml_dtypes::float8_internal::numeric_limits_float8_* struct.
template <> struct numeric_limits
    : public ml_dtypes::float8_internal::numeric_limits_float8_e3m4 {};
template <> struct numeric_limits
    : public ml_dtypes::float8_internal::numeric_limits_float8_e4m3 {};
template <> struct numeric_limits
    : public ml_dtypes::float8_internal::numeric_limits_float8_e4m3fn {};
template <> struct numeric_limits
    : public ml_dtypes::float8_internal::numeric_limits_float8_e4m3b11fnuz {};
template <> struct numeric_limits
    : public ml_dtypes::float8_internal::numeric_limits_float8_e4m3fnuz {};
template <> struct numeric_limits
    : public ml_dtypes::float8_internal::numeric_limits_float8_e5m2 {};
template <> struct numeric_limits
    : public ml_dtypes::float8_internal::numeric_limits_float8_e5m2fnuz {};
template <> struct numeric_limits
    : public ml_dtypes::float8_internal::numeric_limits_float8_e8m0fnu {};
}  // namespace std

namespace ml_dtypes {
namespace float8_internal {

// Free-functions for use with ADL and in Eigen.
//
// abs() clears the sign bit of a value's bit pattern; isnan() tests each
// format's NaN encoding(s). The parenthesized names (e.g. `bool(isnan)`)
// keep any function-like macro of the same name from expanding.

// IEEE-style format: any magnitude above the infinity pattern has an all-ones
// exponent and a non-zero mantissa, i.e. is a NaN.
constexpr inline float8_e3m4 abs(const float8_e3m4& a) {
  return float8_e3m4::FromRep(a.rep() & 0b0'111'1111);
}

constexpr inline bool(isnan)(const float8_e3m4& a) {
  return abs(a).rep() > std::numeric_limits::infinity().rep();
}

// IEEE-style format: same magnitude-above-infinity NaN test as e3m4.
constexpr inline float8_e4m3 abs(const float8_e4m3& a) {
  return float8_e4m3::FromRep(a.rep() & 0b0'1111'111);
}

constexpr inline bool(isnan)(const float8_e4m3& a) {
  return abs(a).rep() > std::numeric_limits::infinity().rep();
}

// e4m3fn: NaN is exactly the all-ones magnitude, with either sign.
constexpr inline float8_e4m3fn abs(const float8_e4m3fn& a) {
  return float8_e4m3fn::FromRep(a.rep() & 0b0'1111'111);
}

constexpr inline bool(isnan)(const float8_e4m3fn& a) {
  return abs(a).rep() == std::numeric_limits::quiet_NaN().rep();
}

// When the magnitude bits are all zero the value is either +0 (0x00) or the
// single NaN (0x80); it is returned unchanged so that abs(NaN) stays NaN
// instead of collapsing to +0.
constexpr inline float8_e4m3b11fnuz abs(const float8_e4m3b11fnuz& a) {
  return (a.rep() & 0b0'1111'111) == 0
             ? float8_e4m3b11fnuz::FromRep(a.rep())
             : float8_e4m3b11fnuz::FromRep(a.rep() & 0b0'1111'111);
}

// Exactly one NaN bit pattern, so compare the raw representation.
constexpr inline bool(isnan)(const float8_e4m3b11fnuz& a) {
  return a.rep() == std::numeric_limits::quiet_NaN().rep();
}

// Same zero/NaN-preserving abs as e4m3b11fnuz (0x7F masks off the sign bit).
constexpr inline float8_e4m3fnuz abs(const float8_e4m3fnuz& a) {
  return (a.rep() & 0x7F) == 0 ? float8_e4m3fnuz::FromRep(a.rep())
                               : float8_e4m3fnuz::FromRep(a.rep() & 0x7F);
}

constexpr inline bool(isnan)(const float8_e4m3fnuz& a) {
  return abs(a).rep() == std::numeric_limits::quiet_NaN().rep();
}

// IEEE-style format: magnitude-above-infinity NaN test.
constexpr inline float8_e5m2 abs(const float8_e5m2& a) {
  return float8_e5m2::FromRep(a.rep() & 0b0'11111'11);
}

constexpr inline bool(isnan)(const float8_e5m2& a) {
  return abs(a).rep() > std::numeric_limits::infinity().rep();
}

// Same zero/NaN-preserving abs as the other fnuz formats.
constexpr inline float8_e5m2fnuz abs(const float8_e5m2fnuz& a) {
  return (a.rep() & 0x7F) == 0 ? float8_e5m2fnuz::FromRep(a.rep())
                               : float8_e5m2fnuz::FromRep(a.rep() & 0x7F);
}

// 0x80 is the single NaN encoding.
constexpr inline bool(isnan)(const float8_e5m2fnuz& a) {
  return a.rep() == 0x80;
}

// Unsigned format: abs is the identity.
constexpr inline float8_e8m0fnu abs(const float8_e8m0fnu& a) { return a; }

// 0xff is the single NaN encoding.
constexpr inline bool(isnan)(const float8_e8m0fnu& a) {
  return a.rep() == 0xff;
}

// True only for formats whose numeric_limits report has_infinity. It compares
// the magnitude against the infinity pattern, so both +inf and -inf match.
template
constexpr inline bool(isinf)(const float8_base& a) {
  if constexpr (std::numeric_limits::has_infinity) {
    return abs(a.derived()).rep() ==
           std::numeric_limits::infinity().rep();
  } else {
    // No inf representation.
    return false;
  }
}

// Finite means neither NaN nor infinite.
template
constexpr inline bool(isfinite)(const float8_base& a) {
  return !isnan(a.derived()) && !isinf(a.derived());
}

// Stream insertion: writes a static_cast of the derived value to `os` and
// returns `os` for chaining.
template
std::ostream& operator<<(std::ostream& os, const float8_base& f8) {
  os << static_cast(f8.derived());
  return os;
}

//==============================================================================
// Inline conversion routines between float8 and other types.
//==============================================================================

// True iff `x` is non-zero and has exactly one bit set: x & (x - 1) clears the
// lowest set bit.
template
bool constexpr IsPowerOfTwo(T x) {
  return (x != 0) && ((x & (x - 1)) == 0);
}

// Helper for getting a bytes size which is a power of two.
// Only 3, 5, 6 and 7 are remapped (to 4 or 8); every other size passes
// through unchanged.
template
struct NextPowerOfTwo {
  static constexpr int value = Size;
};
template <>
struct NextPowerOfTwo<3> {
  static constexpr int value = 4;
};
template <>
struct NextPowerOfTwo<5> {
  static constexpr int value = 8;
};
template <>
struct NextPowerOfTwo<6> {
  static constexpr int value = 8;
};
template <>
struct NextPowerOfTwo<7> {
  static constexpr int value = 8;
};

// Helper for getting a bit representation provided a byte size.
// Resolves to Eigen's unsigned integer type of that size.
template
using GetUnsignedInteger =
    typename Eigen::numext::get_integer_by_size::unsigned_type;

// Converts between two floating-point types.
template
struct ConvertImpl;

// Convert to same type.  We need explicit specializations for all combinations
// of template parameters to avoid ambiguities.
template
struct IdentityConversion {
  static EIGEN_DEVICE_FUNC inline Scalar run(Scalar from) { return from; }
};

template
struct ConvertImpl : public IdentityConversion {};

// Bit-layout description of a floating-point type, derived from
// std::numeric_limits and Eigen::NumTraits. kMantissaBits excludes the
// implicit leading bit. kExponentBias defaults to the IEEE-style
// 2^(kExponentBits - 1) - 1; specializations below override it where needed.
template
struct TraitsBase {
  using BitsType = GetUnsignedInteger;
  static constexpr bool kIsSigned = std::numeric_limits::is_signed;
  static constexpr bool kHasZero = true;
  static constexpr int kBits = sizeof(Float) * CHAR_BIT;
  static constexpr int kMantissaBits = Eigen::NumTraits::digits() - 1;
  // Extra bit used in exponent for unsigned float: with no sign bit, that bit
  // belongs to the exponent field.
  static constexpr int kExponentBits =
      kBits - kMantissaBits - static_cast(kIsSigned);
  static constexpr BitsType kExponentMask = ((BitsType{1} << kExponentBits) - 1)
                                            << kMantissaBits;
  static constexpr BitsType kMantissaMask = (BitsType{1} << kMantissaBits) - 1;
  static constexpr int kExponentBias = (1 << (kExponentBits - 1)) - 1;
};

template
struct Traits : public TraitsBase {};

// Non-default exponent bias of 11 (cf. numeric_limits_float8_e4m3b11fnuz).
template <>
struct Traits : public TraitsBase {
  static constexpr int kExponentBias = 11;
};

// Bias one larger than the IEEE-style default, as in the fnuz formats
// (cf. numeric_limits_float8_e4m3fnuz, bias 8, and
// numeric_limits_float8_e5m2fnuz, bias 16).
template <>
struct Traits : public TraitsBase {
  using Base = TraitsBase;
  static constexpr int kExponentBias = Base::kExponentBias + 1;
};

template <>
struct Traits : public TraitsBase {
  using Base = TraitsBase;
  static constexpr int kExponentBias = Base::kExponentBias + 1;
};

template <>
struct Traits : public TraitsBase {
  using Base = TraitsBase;
  // No zero in E8M0 OCP MX format description.
  static constexpr bool kHasZero = false;
};

// Adds the round-to-nearest-even bias for dropping the low `roundoff` bits of
// `bits`. The low bits are NOT cleared here; callers shift or mask them away
// afterwards. Requires roundoff >= 1.
template
constexpr inline Bits RoundBitsToNearestEven(Bits bits, int roundoff,
                                             bool use_implicit_bit) {
  // Round to nearest even by adding a bias term.
  // Consider a bit pattern
  //   FFF...FLRTT...T,
  // where bits RTT...T need to be rounded-off.  We add a bias term to the
  // bit pattern s.t. a carry is introduced to round up only if
  // - L is 1, R is 1, OR
  // - L is 0, R is 1, any T is one.
  // We do this by adding L to a bit pattern consisting of all T = 1.
  //
  // When rounding to zero mantissa (E8M0 type), the L bit is implicitly 1 (do
  // not use the exponent bits for rounding). Add only the R bit in this case.
  Bits bias = !use_implicit_bit
                  ? ((bits >> roundoff) & 1) + (Bits{1} << (roundoff - 1)) - 1
                  : Bits{1} << (roundoff - 1);
  return bits + bias;
}

#if (defined(__cpp_lib_bitops) && __cpp_lib_bitops >= 201907L)
using std::countl_zero;
#else
// Fallback leading-zero counts for standard libraries without <bit>. Each one
// halves the search range to isolate the highest non-zero nibble, then reads
// that nibble's leading-zero count from a 16-entry table; the string
// literal's terminating '\0' supplies entry 15. x == 0 yields the full width.
static constexpr inline int countl_zero(uint64_t x) {
  int zeroes = 60;
  if (x >> 32) {
    zeroes -= 32;
    x >>= 32;
  }
  if (x >> 16) {
    zeroes -= 16;
    x >>= 16;
  }
  if (x >> 8) {
    zeroes -= 8;
    x >>= 8;
  }
  if (x >> 4) {
    zeroes -= 4;
    x >>= 4;
  }
  return "\4\3\2\2\1\1\1\1\0\0\0\0\0\0\0"[x] + zeroes;
}

static constexpr inline int countl_zero(uint32_t x) {
  int zeroes = 28;
  if (x >> 16) {
    zeroes -= 16;
    x >>= 16;
  }
  if (x >> 8) {
    zeroes -= 8;
    x >>= 8;
  }
  if (x >> 4) {
    zeroes -= 4;
    x >>= 4;
  }
  return "\4\3\2\2\1\1\1\1\0\0\0\0\0\0\0"[x] + zeroes;
}

static constexpr inline int countl_zero(uint16_t x) {
  int zeroes = 12;
  if (x >> 8) {
    zeroes -= 8;
    x >>= 8;
  }
  if (x >> 4) {
    zeroes -= 4;
    x >>= 4;
  }
  return "\4\3\2\2\1\1\1\1\0\0\0\0\0\0\0"[x] + zeroes;
}

static constexpr inline int countl_zero(uint8_t x) {
  int zeroes = 4;
  if (x >> 4) {
    zeroes -= 4;
    x >>= 4;
  }
  return "\4\3\2\2\1\1\1\1\0\0\0\0\0\0\0"[x] + zeroes;
}
#endif

// Generic bit-level conversion between two floating-point formats.
//
// kSaturate: out-of-range results clamp (to highest(), lowest() or
// denorm_min() depending on the case) instead of becoming infinity/NaN.
// kTruncate: skip round-to-nearest-even and simply drop low mantissa bits.
//
// run() handles, in order:
//   1. inf/NaN (sign preserved);
//   2. zero;
//   3. negative input into an unsigned target;
//   4. From-subnormals that may become To-normals;
//   5. values that become To-subnormals;
//   6. the normal path: round, re-bias, realign, overflow check.
template
struct ConvertImpl>> {
  using FromTraits = Traits;
  using FromBits = typename FromTraits::BitsType;
  static constexpr bool kFromIsSigned = FromTraits::kIsSigned;
  static constexpr bool kFromHasZero = FromTraits::kHasZero;
  static constexpr int kFromBits = FromTraits::kBits;
  static constexpr int kFromMantissaBits = FromTraits::kMantissaBits;
  static constexpr int kFromExponentBits = FromTraits::kExponentBits;
  static constexpr int kFromExponentBias = FromTraits::kExponentBias;
  static constexpr FromBits kFromExponentMask = FromTraits::kExponentMask;

  using ToTraits = Traits;
  using ToBits = typename ToTraits::BitsType;
  static constexpr bool kToIsSigned = ToTraits::kIsSigned;
  static constexpr bool kToHasZero = ToTraits::kHasZero;
  static constexpr int kToBits = ToTraits::kBits;
  static constexpr int kToMantissaBits = ToTraits::kMantissaBits;
  static constexpr int kToExponentBits = ToTraits::kExponentBits;
  static constexpr int kToExponentBias = ToTraits::kExponentBias;
  static constexpr ToBits kToExponentMask = ToTraits::kExponentMask;

  // `WideBits` is wide enough to accommodate the largest exponent and mantissa
  // in either `From` or `To`.
  static constexpr int kWideBits =
      (std::max(kToMantissaBits, kFromMantissaBits)) +  // Max significand.
      (std::max(kToExponentBits, kFromExponentBits));   // Max exponent.
  static constexpr int kWideBytesRaw = (kWideBits + (CHAR_BIT - 1)) / CHAR_BIT;
  // Need a power of two (i.e. not 3 bytes).
  static constexpr int kWideBytes = NextPowerOfTwo::value;
  using WideBits = GetUnsignedInteger;
  static_assert(!std::is_void_v, "`WideBits` type can not be void type.");
  // Added to a biased From exponent to obtain the biased To exponent.
  static constexpr int kExponentOffset = kToExponentBias - kFromExponentBias;
  // Positive when To has more mantissa bits than From, negative when fewer.
  static constexpr int kDigitShift = kToMantissaBits - kFromMantissaBits;

  static EIGEN_DEVICE_FUNC inline To run(From from) {
    // Shift bits to destination type, without sign bit.
    // The sign is captured separately and re-applied with unary minus at the
    // end; from_bits holds the magnitude bits of abs(from).
    const bool from_sign_bit =
        Eigen::numext::bit_cast(from) >> (kFromBits - 1) && kFromIsSigned;
    const FromBits from_bits =
        Eigen::numext::bit_cast(Eigen::numext::abs(from));

    // Special values, preserving sign.
    if (Eigen::numext::isinf(from)) {
      return from_sign_bit ? -Eigen::NumTraits::infinity()
                           : Eigen::NumTraits::infinity();
    }
    if (Eigen::numext::isnan(from)) {
      return from_sign_bit ? -Eigen::NumTraits::quiet_NaN()
                           : Eigen::NumTraits::quiet_NaN();
    }
    // Dealing with zero, when `From` has one.
    if (from_bits == 0 && kFromHasZero) {
      if constexpr (kToHasZero) {
        // Keep the sign, if `To` supports it.
        return from_sign_bit && kToIsSigned ? -To{} : To{};
      } else {
        // `To` has no zero: saturate to its smallest value, or produce NaN.
        return kSaturate ? std::numeric_limits::denorm_min()
                         : Eigen::NumTraits::quiet_NaN();
      }
    }

    // `To` unsigned floating format: NaN or saturate.
    if constexpr (!kToIsSigned && kFromIsSigned) {
      if (from_sign_bit) {
        return kSaturate ? std::numeric_limits::lowest()
                         : Eigen::NumTraits::quiet_NaN();
      }
    }

    const int biased_from_exponent = from_bits >> kFromMantissaBits;
    const bool to_zero_mantissa = kToMantissaBits == 0;

    // `To` supports more exponents near zero which means that some subnormal
    // values in `From` may become normal.
    if constexpr (std::numeric_limits::min_exponent <
                  std::numeric_limits::min_exponent) {
      if (biased_from_exponent == 0) {
        // Subnormals.
        WideBits bits = from_bits;

        // Determine exponent in target type.
        // msb is the index of the highest set bit of the subnormal mantissa.
        const int msb =
            sizeof(from_bits) * CHAR_BIT - countl_zero(from_bits) - 1;
        const int normalization_factor = kFromMantissaBits - msb;
        const int biased_exponent = kExponentOffset - normalization_factor + 1;
        if (biased_exponent <= 0) {
          // Result is subnormal.  Adjust the subnormal bits to account for
          // the difference in exponent bias.
          if constexpr (kExponentOffset < sizeof(WideBits) * CHAR_BIT) {
            bits <<= kExponentOffset;
          }
        } else {
          // Result is normal. Shift the mantissa to account for the number of
          // leading zero digits, and clear the hidden bit.
          bits <<= normalization_factor;
          bits &= ~(WideBits{1} << kFromMantissaBits);
          // Insert the exponent bits.
          bits |= static_cast(biased_exponent) << kFromMantissaBits;
        }

        // Truncate/round mantissa if necessary.
        if constexpr (kDigitShift >= 0) {
          bits <<= kDigitShift;
        } else {
          if constexpr (!kTruncate) {
            // When converting float to e8m0, the bits represent a denormal,
            // so don't use the implicit mantissa bit for rounding.
            bits = RoundBitsToNearestEven(
                bits, -kDigitShift, to_zero_mantissa && kExponentOffset != 0);
          }
          bits >>= -kDigitShift;
        }
        To to = Eigen::numext::bit_cast(static_cast(bits));
        return from_sign_bit ? -to : to;
      }
    }
    // `To` supports fewer exponents near zero which means that some values in
    // `From` may become subnormal.
    if constexpr (std::numeric_limits::min_exponent >
                  std::numeric_limits::min_exponent) {
      const int unbiased_exponent = biased_from_exponent - kFromExponentBias;
      const int biased_to_exponent = unbiased_exponent + kToExponentBias;
      // Subnormals and zero.
      if (biased_to_exponent <= 0) {
        // Round and shift mantissa down.
        // Zero exponent valid if From has no zero representation.
        FromBits from_has_leading_one =
            (biased_from_exponent > 0 || !kFromHasZero ? 1 : 0);
        int exponent_shift =
            -kDigitShift - biased_to_exponent + from_has_leading_one;
        // Insert the implicit leading 1 bit on the mantissa for normalized
        // inputs.
        FromBits rounded_from_bits =
            (from_bits & FromTraits::kMantissaMask) |
            (from_has_leading_one << kFromMantissaBits);
        // Stays 0 (underflow to zero) when the shift exceeds the mantissa.
        ToBits bits = 0;
        if (exponent_shift > 0) {
          // To avoid UB, limit rounding and shifting to the full mantissa plus
          // leading 1.
          if (exponent_shift <= kFromMantissaBits + 1) {
            if constexpr (!kTruncate) {
              // NOTE: we need to round again from the original from_bits,
              // otherwise the lower precision bits may already be lost.  There
              // is an edge-case where rounding to a normalized value would
              // normally round down, but for a subnormal, we need to round up.
              rounded_from_bits = RoundBitsToNearestEven(rounded_from_bits,
                                                         exponent_shift, false);
            }
            bits = rounded_from_bits >> exponent_shift;
          }
        } else {
          bits = rounded_from_bits << -exponent_shift;
        }
        // Insert sign and return.
        To to = Eigen::numext::bit_cast(bits);
        return from_sign_bit ? -to : to;
      }
    }

    // Round the mantissa if it is shrinking.
    WideBits rounded_from_bits = from_bits;
    if constexpr (kDigitShift < 0) {
      if constexpr (!kTruncate) {
        rounded_from_bits = RoundBitsToNearestEven(from_bits, -kDigitShift,
                                                   to_zero_mantissa);
      }
      // Zero-out tail bits.
      rounded_from_bits &= ~((WideBits{1} << (-kDigitShift)) - 1);
    }

    // Re-bias the exponent.
    rounded_from_bits += static_cast(kExponentOffset) << kFromMantissaBits;

    ToBits bits;
    // Check for overflows by aligning the significands. We always align the
    // narrower significand to the wider significand.
    const WideBits kToHighestRep =
        Eigen::numext::bit_cast(Eigen::NumTraits::highest());
    WideBits aligned_highest{kToHighestRep};
    if constexpr (kDigitShift < 0) {
      aligned_highest <<= -kDigitShift;
      // Shift down, all dropped bits should already be zero.
      bits = static_cast(rounded_from_bits >> -kDigitShift);
    } else if constexpr (kDigitShift >= 0) {
      // Shift up, inserting zeros in the newly created digits.
      rounded_from_bits <<= kDigitShift;
      bits = static_cast(rounded_from_bits);
    }

    To to = Eigen::numext::bit_cast(bits);
    // `From` supports larger values than `To`, we may overflow.
    if constexpr (std::make_pair(std::numeric_limits::max_exponent,
                                 std::numeric_limits::digits) <
                  std::make_pair(std::numeric_limits::max_exponent,
                                 std::numeric_limits::digits)) {
      if (rounded_from_bits > aligned_highest) {
        // Overflowed values map to highest or infinity depending on kSaturate.
        to = kSaturate ? Eigen::NumTraits::highest()
                       : Eigen::NumTraits::infinity();
      }
    }
    // Insert sign bit.
    return from_sign_bit ? -to : to;
  }
};

// Saturation has no impact when casting e4m3fn to e5m2.
// e4m3fn's largest magnitude (448) lies well inside e5m2's range (57344), so
// no input can overflow; this specialization just forwards to another
// ConvertImpl instantiation.
template
struct ConvertImpl {
  static EIGEN_DEVICE_FUNC inline float8_e5m2 run(float8_e4m3fn from) {
    return ConvertImpl::run(from);
  }
};

// Eigen::half -> e5m2. The e5m2 encoding is the high byte of the half
// encoding (same sign bit and 5-bit exponent; half's infinity 0x7C00 becomes
// e5m2's 0x7C). The conversion therefore only has to deal with the low 8
// bits of the half.
template
struct ConvertImpl {
  static EIGEN_DEVICE_FUNC inline float8_e5m2 run(Eigen::half from) {
    uint16_t from_bits = Eigen::numext::bit_cast(from);

    // Special values (Inf or NaN).
    uint16_t abs_bits = from_bits & 0x7FFF;
    if (abs_bits == 0x7C00) {
      return float8_e5m2::FromRep(from_bits >> 8);
    } else if (abs_bits > 0x7C00) {
      // IEEE 754-2019 6.2.1: "A quiet NaN bit string should be encoded with the
      // first bit (d1) of the trailing significand field T being 1."
      // IEEE 754-2019 6.2.3: "Conversion of a quiet NaN to a floating-point
      // format of the same or a different radix that does not allow the payload
      // to be preserved, shall return a quiet NaN [...]"
      // Setting the quiet bit also keeps a NaN whose payload lived only in the
      // dropped low byte from turning into infinity.
      return float8_e5m2::FromRep((from_bits >> 8) | 0b0'00000'10);
    }

    if constexpr (!kTruncate) {
      from_bits = RoundBitsToNearestEven(from_bits, 8, false);
      // Rounding can cause an overflow to infinity. Clamp to the largest finite
      // value if saturation is requested.
      if constexpr (kSaturate) {
        const float8_e5m2 kHighest = Eigen::NumTraits::highest();
        if ((from_bits & 0x7F00) > static_cast(kHighest.rep()) << 8) {
          const bool from_sign_bit = from_bits >> 15;
          return from_sign_bit ? -kHighest : kHighest;
        }
      }
    }
    return float8_e5m2::FromRep(from_bits >> 8);
  }
};

// Direct casts of e5m2 to Eigen::half simply shift bits over.
// This is exact: the e5m2 byte becomes the high byte of the half, and the low
// byte is zero.
template
struct ConvertImpl {
  static EIGEN_DEVICE_FUNC inline Eigen::half run(float8_e5m2 from) {
    return Eigen::numext::bit_cast(
        static_cast(static_cast(from.rep()) << 8));
  }
};

// Converts an arbitrary floating-point value to this float8 type.
template
template
EIGEN_DEVICE_FUNC Derived float8_base::ConvertFrom(const From from) {
  // We are rounding long double -> float -> float8. This can induce
  // double-rounding which may alter the results. We can correct for this using
  // a trick explained in: Boldo, Sylvie, and Guillaume Melquiond. "When double
  // rounding is odd." 17th IMACS World Congress. 2005.
  // The first rounding (to float) is made round-to-odd by hand below, and the
  // static_asserts check the extra-precision/exponent-range preconditions that
  // trick relies on.
  if constexpr (std::is_floating_point_v &&
                sizeof(From) > sizeof(double)) {
    // float80, binary128, etc. end up here.
    static_assert(std::numeric_limits::digits >=
                  std::numeric_limits::digits + 2);
    static_assert(std::numeric_limits::min_exponent >=
                  std::numeric_limits::min_exponent + 2);
    static_assert(std::numeric_limits::is_iec559);
    static_assert(std::numeric_limits::radix == 2);
    const bool is_negative = std::signbit(from);
    const From abs_wide = std::fabs(from);
    float abs_narrow = static_cast(abs_wide);
    const From abs_narrow_as_wide = static_cast(abs_narrow);

    uint32_t narrow_bits = Eigen::numext::bit_cast(abs_narrow);
    // We can keep the narrow value as-is if narrowing was exact (no rounding
    // error), the wide value was NaN (the narrow value is also NaN and should
    // be preserved) or if we rounded to the odd value.
    const bool keep_narrow = (abs_wide == abs_narrow_as_wide) ||
                             std::isnan(abs_narrow) || (narrow_bits & 1);
    // We morally performed a round-down if `abs_narrow` is smaller than
    // `abs_wide`.
    const bool narrow_is_rd = abs_wide > abs_narrow_as_wide;
    // If the narrow value is odd or exact, pick it.
    // Otherwise, narrow is even and corresponds to either the rounded-up or
    // rounded-down value. If narrow is the rounded-down value, we want the
    // rounded-up value as it will be odd.
    // (Adding -1 to the unsigned uint32_t wraps, i.e. decrements.)
    narrow_bits += keep_narrow ? 0 : narrow_is_rd ? 1 : -1;
    abs_narrow = Eigen::numext::bit_cast(narrow_bits);
    return ConvertImpl::run(
        is_negative ? -abs_narrow : abs_narrow);
  } else {
    return ConvertImpl::run(from);
  }
}

// Converts this float8 value to `To` via ConvertImpl.
template
template
EIGEN_DEVICE_FUNC To float8_base::ConvertTo(Derived from) {
  return ConvertImpl::run(from);
}

}  // namespace float8_internal

// Exported types.
using float8_e3m4 = float8_internal::float8_e3m4;
using float8_e4m3 = float8_internal::float8_e4m3;
using float8_e4m3fn = float8_internal::float8_e4m3fn;
using float8_e4m3fnuz = float8_internal::float8_e4m3fnuz;
using float8_e4m3b11fnuz = float8_internal::float8_e4m3b11fnuz;
using float8_e5m2 = float8_internal::float8_e5m2;
using float8_e5m2fnuz = float8_internal::float8_e5m2fnuz;
using float8_e8m0fnu = float8_internal::float8_e8m0fnu;

}  // namespace ml_dtypes

// Work-around for isinf/isnan/isfinite issue on aarch64.
//
// Every exported float8 type gets explicit specializations of Eigen's
// internal isinf_impl / isnan_impl / isfinite_impl. Each one simply forwards
// to the matching ml_dtypes::float8_internal::isinf / isnan / isfinite.
namespace Eigen {
namespace internal {

template <>
EIGEN_DEVICE_FUNC inline bool isinf_impl(
    const ml_dtypes::float8_e3m4& x) {
  return ml_dtypes::float8_internal::isinf(x);
}

template <>
EIGEN_DEVICE_FUNC inline bool isinf_impl(
    const ml_dtypes::float8_e4m3& x) {
  return ml_dtypes::float8_internal::isinf(x);
}

template <>
EIGEN_DEVICE_FUNC inline bool isinf_impl(
    const ml_dtypes::float8_e4m3fn& x) {
  return ml_dtypes::float8_internal::isinf(x);
}

template <>
EIGEN_DEVICE_FUNC inline bool isinf_impl(
    const ml_dtypes::float8_e4m3b11fnuz& x) {
  return ml_dtypes::float8_internal::isinf(x);
}

template <>
EIGEN_DEVICE_FUNC inline bool isinf_impl(
    const ml_dtypes::float8_e4m3fnuz& x) {
  return ml_dtypes::float8_internal::isinf(x);
}

template <>
EIGEN_DEVICE_FUNC inline bool isinf_impl(
    const ml_dtypes::float8_e5m2& x) {
  return ml_dtypes::float8_internal::isinf(x);
}

template <>
EIGEN_DEVICE_FUNC inline bool isinf_impl(
    const ml_dtypes::float8_e5m2fnuz& x) {
  return ml_dtypes::float8_internal::isinf(x);
}

template <>
EIGEN_DEVICE_FUNC inline bool isinf_impl(
    const ml_dtypes::float8_e8m0fnu& x) {
  return ml_dtypes::float8_internal::isinf(x);
}

template <>
EIGEN_DEVICE_FUNC inline bool isnan_impl(
    const ml_dtypes::float8_e3m4& x) {
  return ml_dtypes::float8_internal::isnan(x);
}

template <>
EIGEN_DEVICE_FUNC inline bool isnan_impl(
    const ml_dtypes::float8_e4m3& x) {
  return ml_dtypes::float8_internal::isnan(x);
}

template <>
EIGEN_DEVICE_FUNC inline bool isnan_impl(
    const ml_dtypes::float8_e4m3fn& x) {
  return ml_dtypes::float8_internal::isnan(x);
}

template <>
EIGEN_DEVICE_FUNC inline bool isnan_impl(
    const ml_dtypes::float8_e4m3b11fnuz& x) {
  return ml_dtypes::float8_internal::isnan(x);
}

template <>
EIGEN_DEVICE_FUNC inline bool isnan_impl(
    const ml_dtypes::float8_e4m3fnuz& x) {
  return ml_dtypes::float8_internal::isnan(x);
}

template <>
EIGEN_DEVICE_FUNC inline bool isnan_impl(
    const ml_dtypes::float8_e5m2& x) {
  return ml_dtypes::float8_internal::isnan(x);
}

template <>
EIGEN_DEVICE_FUNC inline bool isnan_impl(
    const ml_dtypes::float8_e5m2fnuz& x) {
  return ml_dtypes::float8_internal::isnan(x);
}

template <>
EIGEN_DEVICE_FUNC inline bool isnan_impl(
    const ml_dtypes::float8_e8m0fnu& x) {
  return ml_dtypes::float8_internal::isnan(x);
}

template <>
EIGEN_DEVICE_FUNC inline bool isfinite_impl(
    const ml_dtypes::float8_e3m4& x) {
  return ml_dtypes::float8_internal::isfinite(x);
}

template <>
EIGEN_DEVICE_FUNC inline bool isfinite_impl(
    const ml_dtypes::float8_e4m3& x) {
  return ml_dtypes::float8_internal::isfinite(x);
}

template <>
EIGEN_DEVICE_FUNC inline bool isfinite_impl(
    const ml_dtypes::float8_e4m3fn& x) {
  return ml_dtypes::float8_internal::isfinite(x);
}

template <>
EIGEN_DEVICE_FUNC inline bool isfinite_impl(
    const ml_dtypes::float8_e4m3b11fnuz& x) {
  return ml_dtypes::float8_internal::isfinite(x);
}

template <>
EIGEN_DEVICE_FUNC inline bool isfinite_impl(
    const ml_dtypes::float8_e4m3fnuz& x) {
  return ml_dtypes::float8_internal::isfinite(x);
}

template <>
EIGEN_DEVICE_FUNC inline bool isfinite_impl(
    const ml_dtypes::float8_e5m2& x) {
  return ml_dtypes::float8_internal::isfinite(x);
}

template <>
EIGEN_DEVICE_FUNC inline bool isfinite_impl(
    const ml_dtypes::float8_e5m2fnuz& x) {
  return ml_dtypes::float8_internal::isfinite(x);
}

template <>
EIGEN_DEVICE_FUNC inline bool isfinite_impl(
    const ml_dtypes::float8_e8m0fnu& x) {
  return ml_dtypes::float8_internal::isfinite(x);
}

}  // namespace internal
}  // namespace Eigen

#endif  // ML_DTYPES_FLOAT8_H_
jax-ml-ml_dtypes-882eb0f/ml_dtypes/include/intn.h000066400000000000000000000251331510671665600221050ustar00rootroot00000000000000/* Copyright 2023 The ml_dtypes Authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef ML_DTYPES_INTN_H_
#define ML_DTYPES_INTN_H_

// NOTE(review): in this copy the header names after each #include, and most
// template parameter/argument lists, appear to have been stripped. For
// example, `template struct intN` has no parameter list, yet its body uses
// N and UnderlyingTy. Restore them from the upstream ml_dtypes sources
// before compiling.
#include
#include
#include
#include
#include
#include
#include

namespace ml_dtypes {

// Stores the n-bit integer value in the low n bits of a byte. The upper
// bits are left unspecified and ignored.
//
// How operators treat the value:
//   - Most arithmetic and bitwise operators compute on the raw storage v_
//     and wrap the result in intN(...). The intN(UnderlyingTy) constructor
//     applies Mask(), so only the low N bits of a result are kept.
//   - Division, modulo, right shift, ordering comparisons and conversions
//     use IntValue(), the value widened back to the full UnderlyingTy.
//   - Equality compares the masked raw bits.
template struct intN {
 private:
  UnderlyingTy v_;
  using SignedUnderlyingTy = std::make_signed_t;
  using UnsignedUnderlyingTy = std::make_unsigned_t;

  static constexpr int kUnderlyingBits = std::numeric_limits::digits;

  static_assert(
      std::is_same_v || std::is_same_v,
      "The underyling type must be a signed or unsigned 8-bit integer.");

  // Mask the upper bits.
  //
  // Shifts the N value bits to the top of the storage and back down again,
  // discarding whatever was above them.
  static inline constexpr UnderlyingTy Mask(UnderlyingTy v) {
    return static_cast(
               static_cast(v) << (kUnderlyingBits - N)) >>
           (kUnderlyingBits - N);
  }

  // Mask the upper bits and sign-extend for signed types.
  //
  // The visible tokens match Mask(). The difference presumably lies in the
  // (stripped) cast targets, which make the final right shift copy the sign
  // bit for signed UnderlyingTy.
  static inline constexpr UnderlyingTy ExtendToFullWidth(UnderlyingTy v) {
    return static_cast(static_cast(v) << (kUnderlyingBits - N)) >>
           (kUnderlyingBits - N);
  }

  // Casts to the corresponding UnderlyingTy value.
  // Used wherever the numeric value matters rather than the raw bits.
  inline constexpr UnderlyingTy IntValue() const {
    return ExtendToFullWidth(v_);
  }

 public:
  constexpr intN() noexcept : v_(0) {}
  constexpr intN(const intN& other) noexcept = default;
  constexpr intN(intN&& other) noexcept = default;
  constexpr intN& operator=(const intN& other) = default;
  constexpr intN& operator=(intN&&) = default;

  // Keeps only the low N bits of `val`.
  explicit constexpr intN(UnderlyingTy val) : v_(Mask(val)) {}

  // Converts any other value by casting it and delegating to another
  // constructor. Presumably the cast is to UnderlyingTy, so the value is
  // masked by the constructor above.
  template explicit constexpr intN(T t) : intN(static_cast(t)) {}

  using underlying_type = UnderlyingTy;
  // Total width in bits.
  static constexpr int bits = N;
  // Number of value (non-sign) bits: N - 1 for signed types, N otherwise.
  static constexpr int digits = std::is_signed_v ? N - 1 : N;

  // Largest representable value, 2^digits - 1.
  static constexpr intN highest() { return intN((1 << digits) - 1); }
  // Smallest representable value. For signed types this is the pattern
  // with only bit `digits` set (-2^digits); for unsigned types it is 0.
  static constexpr intN lowest() {
    return std::is_signed_v ? intN(1) << digits : intN(0);
  }

  // Explicit conversion of the widened value (IntValue()) to another type.
  template explicit constexpr operator T() const {
    return static_cast(IntValue());
  }
  // NOLINTNEXTLINE(google-explicit-constructor)
  constexpr operator std::optional() const {
    return static_cast(IntValue());
  }

  // Arithmetic on the raw bits; intN(...) re-masks the result.
  constexpr intN operator-() const { return intN(-v_); }
  constexpr intN operator+(const intN& other) const {
    return intN(v_ + other.v_);
  }
  constexpr intN operator-(const intN& other) const {
    return intN(v_ - other.v_);
  }
  constexpr intN operator*(const intN& other) const {
    return intN(v_ * other.v_);
  }
  // Division and modulo use the widened values. A zero divisor is not
  // checked here.
  constexpr intN operator/(const intN& other) const {
    return intN(IntValue() / other.IntValue());
  }
  constexpr intN operator%(const intN& other) const {
    return intN((IntValue() % other.IntValue()));
  }
  // Bitwise operators on the raw bits.
  constexpr intN operator&(const intN& other) const {
    return intN(v_ & other.v_);
  }
  constexpr intN operator|(const intN& other) const {
    return intN(v_ | other.v_);
  }
  constexpr intN operator^(const intN& other) const {
    return intN(v_ ^ other.v_);
  }
  constexpr intN operator~() const { return intN(~v_); }
  // Right shift works on the widened value; left shift on the raw bits.
  constexpr intN operator>>(int amount) const {
    return intN(IntValue() >> amount);
  }
  constexpr intN operator<<(int amount) const { return intN(v_ << amount); }

  // Equality compares the low N bits only.
  constexpr bool operator==(const intN& other) const {
    return Mask(v_) == Mask(other.v_);
  }
  constexpr bool operator!=(const intN& other) const {
    return Mask(v_) != Mask(other.v_);
  }
  // Ordering compares the widened values.
  constexpr bool operator<(const intN& other) const {
    return IntValue() < other.IntValue();
  }
  constexpr bool operator>(const intN& other) const {
    return IntValue() > other.IntValue();
  }
  constexpr bool operator<=(const intN& other) const {
    return IntValue() <= other.IntValue();
  }
  constexpr bool operator>=(const intN& other) const {
    return IntValue() >= other.IntValue();
  }

  // Mixed comparisons against int64_t, in both operand orders.
  constexpr bool operator==(int64_t other) const { return IntValue() == other; }
  constexpr bool operator!=(int64_t other) const { return IntValue() != other; }
  constexpr bool operator<(int64_t other) const { return IntValue() < other; }
  constexpr bool operator>(int64_t other) const { return IntValue() > other; }
  constexpr bool operator<=(int64_t other) const { return IntValue() <= other; }
  constexpr bool operator>=(int64_t other) const { return IntValue() >= other; }
  friend constexpr bool operator==(int64_t a, const intN& b) {
    return a == b.IntValue();
  }
  friend constexpr bool operator!=(int64_t a, const intN& b) {
    return a != b.IntValue();
  }
  friend constexpr bool operator<(int64_t a, const intN& b) {
    return a < b.IntValue();
  }
  friend constexpr bool operator>(int64_t a, const intN& b) {
    return a > b.IntValue();
  }
  friend constexpr bool operator<=(int64_t a, const intN& b) {
    return a <= b.IntValue();
  }
  friend constexpr bool operator>=(int64_t a, const intN& b) {
    return a >= b.IntValue();
  }

  // Increment/decrement wrap within N bits via Mask().
  constexpr intN& operator++() {
    v_ = Mask(v_ + 1);
    return *this;
  }
  constexpr intN operator++(int) {
    intN orig = *this;
    this->operator++();
    return orig;
  }
  constexpr intN& operator--() {
    v_ = Mask(v_ - 1);
    return *this;
  }
  constexpr intN operator--(int) {
    intN orig = *this;
    this->operator--();
    return orig;
  }

  // Compound assignments are defined in terms of the binary operators.
  constexpr intN& operator+=(const intN& other) {
    *this = *this + other;
    return *this;
  }
  constexpr intN& operator-=(const intN& other) {
    *this = *this - other;
    return *this;
  }
  constexpr intN& operator*=(const intN& other) {
    *this = *this * other;
    return *this;
  }
  constexpr intN& operator/=(const intN& other) {
    *this = *this / other;
    return *this;
  }
  constexpr intN& operator%=(const intN& other) {
    *this = *this % other;
    return *this;
  }
  constexpr intN& operator&=(const intN& other) {
    *this = *this & other;
    return *this;
  }
  constexpr intN& operator|=(const intN& other) {
    *this = *this | other;
    return *this;
  }
  constexpr intN& operator^=(const intN& other) {
    *this = *this ^ other;
    return *this;
  }
  constexpr intN& operator>>=(int amount) {
    *this = *this >> amount;
    return *this;
  }
  constexpr intN& operator<<=(int amount) {
    *this = *this << amount;
    return *this;
  }

  // Streams `num` after a cast. Presumably the cast is to a wide integer,
  // so the value prints as a number rather than as a character.
  friend ::std::ostream& operator<<(::std::ostream& os, const intN& num) {
    os << static_cast(num);
    return os;
  }

  // Same textual form as operator<<, returned as a string.
  std::string ToString() const {
    std::ostringstream os;
    os << static_cast(*this);
    return os.str();
  }
};

using int1 = intN<1, int8_t>;
using int2 = intN<2, int8_t>;
using uint1 = intN<1, uint8_t>;
using uint2 = intN<2, uint8_t>;
using int4 = intN<4, int8_t>;
using uint4 = intN<4, uint8_t>;

namespace internal {

// Shared std::numeric_limits implementation for the intN types.
template struct intN_numeric_limits_base {
  static inline constexpr const bool is_specialized = true;
  static inline constexpr const bool is_integer = true;
  static inline constexpr const bool is_exact = true;
  static inline constexpr const bool has_infinity = false;
  static inline constexpr const bool has_quiet_NaN = false;
  static inline constexpr const bool has_signaling_NaN = false;
  // has_denorm / has_denorm_loss are only declared before C++23.
#if !defined(__cplusplus) || __cplusplus < 202302L
  static inline constexpr const std::float_denorm_style has_denorm =
      std::denorm_absent;
  static inline constexpr const bool has_denorm_loss = false;
#endif
  static inline constexpr const std::float_round_style round_style =
      std::round_toward_zero;
  static inline constexpr const bool is_iec559 = false;
  static inline constexpr const bool is_bounded = true;
  static inline constexpr const int max_digits10 = 0;  // Not used for integers.
  static inline constexpr const int radix = 2;
  static inline constexpr const int min_exponent = 0;
  static inline constexpr const int min_exponent10 = 0;
  static inline constexpr const int max_exponent = 0;
  static inline constexpr const int max_exponent10 = 0;
  static inline constexpr const bool traps = true;
  static inline constexpr const bool tinyness_before = false;
  static inline constexpr const bool is_signed = std::is_signed_v;
  static inline constexpr const bool is_modulo = !is_signed;
  static inline constexpr const int digits = intN::digits;
  // floor(digits * log10(2))
  // 3/10 approximates log10(2) ~= 0.30103. The integer formula agrees with
  // the exact floor for the digit counts of the aliases above (0 to 4).
  static inline constexpr const int digits10 = (digits * 3) / 10;
  // Float-only queries: integers have no epsilon, infinity or NaN.
  static constexpr intN epsilon() noexcept { return intN(0); }
  static constexpr intN round_error() noexcept { return intN(0); }
  static constexpr intN infinity() noexcept { return intN(0); }
  static constexpr intN quiet_NaN() noexcept { return intN(0); }
  static constexpr intN signaling_NaN() noexcept { return intN(0); }
  static constexpr intN denorm_min() noexcept { return intN(0); }
  // For integer types min() is the most negative value, same as lowest().
  static constexpr intN min() noexcept { return intN::lowest(); }
  static constexpr intN lowest() noexcept { return intN::lowest(); }
  static constexpr intN max() noexcept { return intN::highest(); }
};

}  // namespace internal
}  // namespace ml_dtypes

// One std::numeric_limits specialization per intN alias. The specialized
// type and base template arguments are missing in this copy.
namespace std {
template <>
struct numeric_limits
    : public ml_dtypes::internal::intN_numeric_limits_base {};
template <>
struct numeric_limits
    : public ml_dtypes::internal::intN_numeric_limits_base {};
template <>
struct numeric_limits
    : public ml_dtypes::internal::intN_numeric_limits_base {};
template <>
struct numeric_limits
    : public ml_dtypes::internal::intN_numeric_limits_base {};
template <>
struct numeric_limits
    : public ml_dtypes::internal::intN_numeric_limits_base {};
template <>
struct numeric_limits
    : public ml_dtypes::internal::intN_numeric_limits_base {};
}  // namespace std

#endif  // ML_DTYPES_INTN_H_
jax-ml-ml_dtypes-882eb0f/ml_dtypes/include/mxfloat.h000066400000000000000000000314201510671665600226030ustar00rootroot00000000000000/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #ifndef ML_DTYPES_MXFLOAT_H_ #define ML_DTYPES_MXFLOAT_H_ // Microscaling (MX) floating point formats, as described in // https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf // // Note: this implements the underlying raw data types (e.g. E2M1FN), not the // composite types (e.g. MXFP4). #include #include #include "ml_dtypes/include/float8.h" #include "Eigen/Core" namespace ml_dtypes { namespace mxfloat_internal { // Use 8-bit storage for 6-bit and 4-bit types. 
// Base class for the 6-bit MX types, kept in 8-bit float8_base storage.
// Only the low 6 bits are meaningful:
//   - bit 5 (0x20) is the sign;
//   - bits 0-4 (0x1F) hold exponent and mantissa.
//
// NOTE(review): the template parameter list and the float8_base template
// arguments appear to have been stripped in this copy (the body uses
// `Derived`). Restore them from upstream before compiling.
template class mxfloat6_base : public float8_internal::float8_base {
  using Base = float8_internal::float8_base;
  friend class float8_internal::float8_base;
  using Base::Base;

 public:
  // Number of meaningful bits in the storage byte.
  static constexpr int kBits = 6;

  // False for +0 and -0: the sign bit is ignored.
  explicit EIGEN_DEVICE_FUNC operator bool() const {
    return (Base::rep() & 0x1F) != 0;
  }
  // Negation flips only the sign bit.
  constexpr Derived operator-() const {
    return Derived::FromRep(Base::rep() ^ 0x20);
  }
  // Declaring unary operator- above hides the base class's operator-
  // overloads, so binary subtraction is re-exposed explicitly.
  Derived operator-(const Derived& other) const {
    return Base::operator-(other);
  }
};

// Base class for the 4-bit MX types, kept in 8-bit float8_base storage.
// Only the low 4 bits are meaningful:
//   - bit 3 (0x08) is the sign;
//   - bits 0-2 (0x07) hold exponent and mantissa.
template class mxfloat4_base : public float8_internal::float8_base {
  using Base = float8_internal::float8_base;
  friend class float8_internal::float8_base;
  using Base::Base;

 public:
  // Number of meaningful bits in the storage byte.
  static constexpr int kBits = 4;

  // False for +0 and -0: the sign bit is ignored.
  explicit EIGEN_DEVICE_FUNC operator bool() const {
    return (Base::rep() & 0x07) != 0;
  }
  // Negation flips only the sign bit.
  constexpr Derived operator-() const {
    return Derived::FromRep(Base::rep() ^ 0x08);
  }
  // Re-exposes binary subtraction hidden by the unary operator- above.
  Derived operator-(const Derived& other) const {
    return Base::operator-(other);
  }
};

class float6_e2m3fn : public mxfloat6_base {
  // Exponent: 2, Mantissa: 3, bias: 1.
  // Extended range: no inf, no NaN.
  using Base = mxfloat6_base;
  friend class float8_internal::float8_base;
  using Base::Base;

 public:
  // Explicit conversion from another small-float value `f8` via
  // ConvertFrom. The template constraint is stripped in this copy;
  // presumably it restricts T to float8_base-derived types.
  template = 0>
  explicit EIGEN_DEVICE_FUNC float6_e2m3fn(T f8)
      : float6_e2m3fn(ConvertFrom(f8)) {}
};

class float6_e3m2fn : public mxfloat6_base {
  // Exponent: 3, Mantissa: 2, bias: 3.
  // Extended range: no inf, no NaN.
  using Base = mxfloat6_base;
  friend class float8_internal::float8_base;
  using Base::Base;

 public:
  // Explicit conversion from another small-float value via ConvertFrom
  // (constraint stripped in this copy).
  template = 0>
  explicit EIGEN_DEVICE_FUNC float6_e3m2fn(T f8)
      : float6_e3m2fn(ConvertFrom(f8)) {}
};

class float4_e2m1fn : public mxfloat4_base {
  // Exponent: 2, Mantissa: 1, bias: 1.
  // Extended range: no inf, no NaN.
  using Base = mxfloat4_base;
  friend class float8_internal::float8_base;
  using Base::Base;

 public:
  // Explicit conversion from another small-float value via ConvertFrom
  // (constraint stripped in this copy).
  template = 0>
  explicit EIGEN_DEVICE_FUNC float4_e2m1fn(T f8)
      : float4_e2m1fn(ConvertFrom(f8)) {}
};

// Common properties for specializing std::numeric_limits.
// NOTE(review): the template parameter list (the body uses E and M) and the
// std::numeric_limits arguments below appear to have been stripped in this
// copy. Restore them from upstream before compiling.
template struct numeric_limits_mxfloat_tpl {
 protected:
  // Standard bias 2^(E-1) - 1.
  static constexpr int kExponentBias = (1 << (E - 1)) - 1;
  static constexpr int kMantissaBits = M;

 public:
  // NOLINTBEGIN: these names must match std::numeric_limits.
  static constexpr bool is_specialized = true;
  static constexpr bool is_signed = true;
  static constexpr bool is_integer = false;
  static constexpr bool is_exact = false;
  static constexpr bool has_infinity = false;
  static constexpr bool has_quiet_NaN = false;
  static constexpr bool has_signaling_NaN = false;
  // has_denorm / has_denorm_loss are only declared before C++23.
#if !defined(__cplusplus) || __cplusplus < 202302L
  static constexpr std::float_denorm_style has_denorm = std::denorm_present;
  static constexpr bool has_denorm_loss = false;
#endif
  static constexpr std::float_round_style round_style = std::round_to_nearest;
  static constexpr bool is_iec559 = false;
  static constexpr bool is_bounded = true;
  static constexpr bool is_modulo = false;
  // Mantissa bits plus the implicit leading bit.
  static constexpr int digits = kMantissaBits + 1;
  static constexpr int digits10 = float8_internal::Digits10FromDigits(digits);
  static constexpr int max_digits10 =
      float8_internal::MaxDigits10FromDigits(digits);
  static constexpr int radix = std::numeric_limits::radix;
  // The smallest normal value is 2^(1 - bias). std::numeric_limits counts
  // one past that power, hence the trailing + 1.
  static constexpr int min_exponent = (1 - kExponentBias) + 1;
  static constexpr int min_exponent10 =
      float8_internal::MinExponent10FromMinExponent(min_exponent);
  // With no inf/NaN encodings, the all-ones exponent field is an ordinary
  // finite exponent. The largest unbiased exponent is therefore bias + 1,
  // and std::numeric_limits counts one past it.
  static constexpr int max_exponent = kExponentBias + 2;
  static constexpr int max_exponent10 =
      float8_internal::MaxExponent10FromMaxExponentAndDigits(max_exponent,
                                                             digits);
  static constexpr bool traps = std::numeric_limits::traps;
  static constexpr bool tinyness_before =
      std::numeric_limits::tinyness_before;
  // NOLINTEND
};

struct numeric_limits_float6_e2m3fn : public numeric_limits_mxfloat_tpl<2, 3> {
  // 1.0 * 2^(0) = 1
  static constexpr float6_e2m3fn min() {
    return float6_e2m3fn::FromRep(0b0'01'000);
  }
  // -1.875 * 2^(2) = -7.5
  static constexpr float6_e2m3fn lowest() {
    return float6_e2m3fn::FromRep(0b1'11'111);
  }
  // 1.875 * 2^(2) = 7.5
  static constexpr float6_e2m3fn max() {
    return float6_e2m3fn::FromRep(0b0'11'111);
  }
  // 0.125 * 2^(0) = 0.125
  static constexpr float6_e2m3fn epsilon() {
    return float6_e2m3fn::FromRep(0b0'00'001);
  }
  // 0.25 * 2^(0) = 0.25
  // NOTE(review): std::numeric_limits<float>::round_error() is 0.5. This
  // type uses 0.25, while e3m2fn and e2m1fn use 1. Confirm the intended
  // convention.
  static constexpr float6_e2m3fn round_error() {
    return float6_e2m3fn::FromRep(0b0'00'010);
  }
  // 0.125 * 2^(0) = 0.125
  static constexpr float6_e2m3fn denorm_min() {
    return float6_e2m3fn::FromRep(0b0'00'001);
  }
  // Conversion from NaNs is implementation-defined (by MX specification).
  // There is no NaN encoding (has_quiet_NaN is false); both NaN accessors
  // return the negative-zero bit pattern (sign bit only).
  static constexpr float6_e2m3fn quiet_NaN() {
    return float6_e2m3fn::FromRep(0b1'00'000);
  }
  static constexpr float6_e2m3fn signaling_NaN() {
    return float6_e2m3fn::FromRep(0b1'00'000);
  }
  // No infinity encoding (has_infinity is false); returns the max() pattern.
  static constexpr float6_e2m3fn infinity() {
    return float6_e2m3fn::FromRep(0b0'11'111);
  }
};

struct numeric_limits_float6_e3m2fn : public numeric_limits_mxfloat_tpl<3, 2> {
  // 1.0 * 2^(-2) = 0.25
  static constexpr float6_e3m2fn min() {
    return float6_e3m2fn::FromRep(0b0'001'00);
  }
  // -1.75 * 2^(4) = -28
  static constexpr float6_e3m2fn lowest() {
    return float6_e3m2fn::FromRep(0b1'111'11);
  }
  // 1.75 * 2^(4) = 28
  static constexpr float6_e3m2fn max() {
    return float6_e3m2fn::FromRep(0b0'111'11);
  }
  // 1.0 * 2^(-2) = 0.25
  static constexpr float6_e3m2fn epsilon() {
    return float6_e3m2fn::FromRep(0b0'001'00);
  }
  // 1.0 * 2^(0) = 1
  static constexpr float6_e3m2fn round_error() {
    return float6_e3m2fn::FromRep(0b0'011'00);
  }
  // 0.25 * 2^(-2) = 0.0625
  static constexpr float6_e3m2fn denorm_min() {
    return float6_e3m2fn::FromRep(0b0'000'01);
  }
  // Conversion from NaNs is implementation-defined (by MX specification).
  // There is no NaN encoding; both NaN accessors return the negative-zero
  // bit pattern.
  static constexpr float6_e3m2fn quiet_NaN() {
    return float6_e3m2fn::FromRep(0b1'000'00);
  }
  static constexpr float6_e3m2fn signaling_NaN() {
    return float6_e3m2fn::FromRep(0b1'000'00);
  }
  // No infinity encoding; returns the max() pattern.
  static constexpr float6_e3m2fn infinity() {
    return float6_e3m2fn::FromRep(0b0'111'11);
  }
};

struct numeric_limits_float4_e2m1fn : public numeric_limits_mxfloat_tpl<2, 1> {
  // 1.0 * 2^(0) = 1
  static constexpr float4_e2m1fn min() {
    return float4_e2m1fn::FromRep(0b0'01'0);
  }
  // -1.5 * 2^(2) = -6
  static constexpr float4_e2m1fn lowest() {
    return float4_e2m1fn::FromRep(0b1'11'1);
  }
  // 1.5 * 2^(2) = 6
  static constexpr float4_e2m1fn max() {
    return float4_e2m1fn::FromRep(0b0'11'1);
  }
  // 0.5 * 2^(0) = 0.5
  static constexpr float4_e2m1fn epsilon() {
    return float4_e2m1fn::FromRep(0b0'00'1);
  }
  // 1.0 * 2^(0) = 1
  static constexpr float4_e2m1fn round_error() {
    return float4_e2m1fn::FromRep(0b0'01'0);
  }
  // 0.5 * 2^(0) = 0.5
  static constexpr float4_e2m1fn denorm_min() {
    return float4_e2m1fn::FromRep(0b0'00'1);
  }
  // Conversion from NaNs is implementation-defined (by MX specification).
  // There is no NaN encoding; both NaN accessors return the negative-zero
  // bit pattern.
  static constexpr float4_e2m1fn quiet_NaN() {
    return float4_e2m1fn::FromRep(0b1'00'0);
  }
  static constexpr float4_e2m1fn signaling_NaN() {
    return float4_e2m1fn::FromRep(0b1'00'0);
  }
  // No infinity encoding; returns the max() pattern.
  static constexpr float4_e2m1fn infinity() {
    return float4_e2m1fn::FromRep(0b0'11'1);
  }
};

// Free-functions for use with ADL and in Eigen.
// abs() clears the sign bit. isnan() is always false because these types
// have no NaN encoding.
constexpr inline float6_e2m3fn abs(const float6_e2m3fn& a) {
  return float6_e2m3fn::FromRep(a.rep() & 0b0'11'111);
}

constexpr inline bool(isnan)(const float6_e2m3fn& a) { return false; }

constexpr inline float6_e3m2fn abs(const float6_e3m2fn& a) {
  return float6_e3m2fn::FromRep(a.rep() & 0b0'111'11);
}

constexpr inline bool(isnan)(const float6_e3m2fn& a) { return false; }

constexpr inline float4_e2m1fn abs(const float4_e2m1fn& a) {
  return float4_e2m1fn::FromRep(a.rep() & 0b0'11'1);
}

constexpr inline bool(isnan)(const float4_e2m1fn& a) { return false; }

// Define traits required for floating point conversion.
// Bit-layout constants for an MX type with E exponent bits and M mantissa
// bits, consumed by the float8_internal conversion code.
// NOTE(review): the template parameter list (the body uses E and M) and the
// template arguments of the base class and of the specializations below
// appear to have been stripped in this copy. Restore them from upstream.
template struct TraitsBase : public float8_internal::TraitsBase {
  // Sign bit + exponent bits + mantissa bits.
  static constexpr int kBits = E + M + 1;
  static constexpr int kMantissaBits = M;
  static constexpr int kExponentBits = E;
  static constexpr int kExponentBias = (1 << (E - 1)) - 1;
  // The exponent field sits directly above the M mantissa bits.
  static constexpr uint8_t kExponentMask = ((1 << E) - 1) << M;
};

}  // namespace mxfloat_internal

// Exported types.
using float6_e2m3fn = mxfloat_internal::float6_e2m3fn;
using float6_e3m2fn = mxfloat_internal::float6_e3m2fn;
using float4_e2m1fn = mxfloat_internal::float4_e2m1fn;

}  // namespace ml_dtypes

// Standard library overrides.
namespace std {

template <>
struct numeric_limits
    : public ml_dtypes::mxfloat_internal::numeric_limits_float6_e2m3fn {};

template <>
struct numeric_limits
    : public ml_dtypes::mxfloat_internal::numeric_limits_float6_e3m2fn {};

template <>
struct numeric_limits
    : public ml_dtypes::mxfloat_internal::numeric_limits_float4_e2m1fn {};

}  // namespace std

// Conversion traits.
// One specialization per MX type. Presumably these are in the same order
// as above (e2m3fn, e3m2fn, e2m1fn); the arguments are missing in this copy.
namespace ml_dtypes {
namespace float8_internal {

template <>
struct Traits : public mxfloat_internal::TraitsBase {};

template <>
struct Traits : public mxfloat_internal::TraitsBase {};

template <>
struct Traits : public mxfloat_internal::TraitsBase {};

}  // namespace float8_internal
}  // namespace ml_dtypes

// Eigen library overrides.
namespace Eigen {
namespace numext {

// signbit() for the MX types works in three steps:
//   1. Shift left by (8 - kBits), moving the type's sign bit (bit
//      kBits - 1) into bit 7 of an int8_t.
//   2. Right-shift by 7 so that bit is copied across the byte.
//   3. Bit-cast the result back to the MX type.
// Presumably this yields an all-ones pattern when the sign bit is set and
// zero otherwise.
#define MXFLOAT_EIGEN_SIGNBIT_IMPL(Type)                              \
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Type signbit(const Type& x) { \
    int8_t t = bit_cast(x) << (8 - Type::kBits);                      \
    return bit_cast(t >> 7);                                          \
  }

MXFLOAT_EIGEN_SIGNBIT_IMPL(ml_dtypes::float6_e2m3fn)
MXFLOAT_EIGEN_SIGNBIT_IMPL(ml_dtypes::float6_e3m2fn)
MXFLOAT_EIGEN_SIGNBIT_IMPL(ml_dtypes::float4_e2m1fn)

#undef MXFLOAT_EIGEN_SIGNBIT_IMPL

}  // namespace numext

// Work-around for isinf/isnan/isfinite issue on aarch64.
namespace internal { #define MXFLOAT_EIGEN_ISFINITE_IMPL(Type) \ template <> \ EIGEN_DEVICE_FUNC inline bool isinf_impl(const Type&) { \ return false; \ } \ template <> \ EIGEN_DEVICE_FUNC inline bool isnan_impl(const Type&) { \ return false; \ } \ template <> \ EIGEN_DEVICE_FUNC inline bool isfinite_impl(const Type&) { \ return true; \ } MXFLOAT_EIGEN_ISFINITE_IMPL(ml_dtypes::float6_e2m3fn) MXFLOAT_EIGEN_ISFINITE_IMPL(ml_dtypes::float6_e3m2fn) MXFLOAT_EIGEN_ISFINITE_IMPL(ml_dtypes::float4_e2m1fn) #undef MXFLOAT_EIGEN_ISFINITE_IMPL } // namespace internal } // namespace Eigen #endif // ML_DTYPES_MXFLOAT_H_ jax-ml-ml_dtypes-882eb0f/ml_dtypes/py.typed000066400000000000000000000000001510671665600210220ustar00rootroot00000000000000jax-ml-ml_dtypes-882eb0f/ml_dtypes/tests/000077500000000000000000000000001510671665600204775ustar00rootroot00000000000000jax-ml-ml_dtypes-882eb0f/ml_dtypes/tests/conftest.py000066400000000000000000000014201510671665600226730ustar00rootroot00000000000000# Copyright 2024 The ml_dtypes Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """pytest configuration file.""" import pathlib import sys # Add ml_dtypes/tests folder to discover multi_thread_utils.py module sys.path.insert(0, str(pathlib.Path(__file__).absolute().parent)) jax-ml-ml_dtypes-882eb0f/ml_dtypes/tests/custom_float_test.py000066400000000000000000001130101510671665600246030ustar00rootroot00000000000000# Copyright 2022 The ml_dtypes Authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test cases for custom floating point types.""" import collections import contextlib import copy import itertools import math import pickle import sys from typing import Type import warnings from absl.testing import absltest from absl.testing import parameterized import ml_dtypes from multi_thread_utils import multi_threaded import numpy as np bfloat16 = ml_dtypes.bfloat16 float4_e2m1fn = ml_dtypes.float4_e2m1fn float6_e2m3fn = ml_dtypes.float6_e2m3fn float6_e3m2fn = ml_dtypes.float6_e3m2fn float8_e3m4 = ml_dtypes.float8_e3m4 float8_e4m3 = ml_dtypes.float8_e4m3 float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11fnuz float8_e4m3fn = ml_dtypes.float8_e4m3fn float8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz float8_e5m2 = ml_dtypes.float8_e5m2 float8_e5m2fnuz = ml_dtypes.float8_e5m2fnuz float8_e8m0fnu = ml_dtypes.float8_e8m0fnu try: # numpy >= 2.0 ComplexWarning = np.exceptions.ComplexWarning except AttributeError: # numpy < 2.0 ComplexWarning = np.ComplexWarning @contextlib.contextmanager def ignore_warning(**kw): with warnings.catch_warnings(): warnings.filterwarnings("ignore", **kw) yield def numpy_assert_allclose(a, b, float_type, **kwargs): a = a.astype(np.float32) if a.dtype == float_type else a b = b.astype(np.float32) if b.dtype == float_type else b return np.testing.assert_allclose(a, b, **kwargs) def numpy_promote_types( a: Type[np.generic], b: Type[np.generic], float_type: Type[np.generic], next_largest_fp_type: Type[np.generic], ) -> 
Type[np.generic]: if a == float_type and b == float_type: return float_type if a == float_type: a = next_largest_fp_type if b == float_type: b = next_largest_fp_type return np.promote_types(a, b) def truncate(x, float_type): if isinstance(x, np.ndarray): return x.astype(float_type).astype(np.float32) else: return type(x)(float_type(x)) def binary_operation_test(a, b, op, float_type): a = float_type(a) b = float_type(b) expected = op(np.float32(a), np.float32(b)) result = op(a, b) if math.isnan(expected): if dtype_has_nan(float_type) and not math.isnan(result): raise AssertionError("%s expected to be nan." % repr(result)) else: np.testing.assert_equal( truncate(expected, float_type=float_type), float(result) ) def dtype_has_inf(dtype): """Determines if the dtype has an `inf` representation.""" try: return np.isinf(dtype(float("inf"))) except (OverflowError, ValueError): return False def dtype_has_nan(dtype): """Determines if the dtype has an `nan` representation.""" try: return np.isnan(dtype(float("nan"))) except (OverflowError, ValueError): return False def dtype_is_signed(dtype): """Determines if the floating dtype has a sign bit.""" return ml_dtypes.finfo(dtype).min < 0 FLOAT_DTYPES = [ bfloat16, float4_e2m1fn, float6_e2m3fn, float6_e3m2fn, float8_e3m4, float8_e4m3, float8_e4m3b11fnuz, float8_e4m3fn, float8_e4m3fnuz, float8_e5m2, float8_e5m2fnuz, float8_e8m0fnu, ] NUMPY_DTYPES = [ np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float16, np.float32, np.float64, ] # Values that should round trip exactly to float and back. 
# pylint: disable=g-complex-comprehension FLOAT_VALUES = { dtype: [ 0.0, 1.0, -1.0, 0.5, -0.5, float(ml_dtypes.finfo(dtype).eps), 1.0 + float(ml_dtypes.finfo(dtype).eps), 1.0 - float(ml_dtypes.finfo(dtype).eps), -1.0 - float(ml_dtypes.finfo(dtype).eps), -1.0 + float(ml_dtypes.finfo(dtype).eps), 3.5, 4, 5, 7, float(ml_dtypes.finfo(dtype).max), -float(ml_dtypes.finfo(dtype).max), float("nan") if dtype_has_nan(dtype) else 0.0, float("-nan") if dtype_has_nan(dtype) else 0.0, float("inf") if dtype_has_inf(dtype) else 0.0, float("-inf") if dtype_has_inf(dtype) else 0.0, ] for dtype in FLOAT_DTYPES } # E8M0 specific values FLOAT_VALUES[float8_e8m0fnu] = [ 0.125, 1.0, 0.5, 1.0 + float(ml_dtypes.finfo(float8_e8m0fnu).eps), 4, float(ml_dtypes.finfo(float8_e8m0fnu).max), float("nan"), ] # Remove values unsupported by some types. FLOAT_VALUES[float4_e2m1fn] = [ x for x in FLOAT_VALUES[float4_e2m1fn] if x not in {3.5, 5, 7} ] # Values that should round trip exactly to integer and back. INT_VALUES = { bfloat16: [0, 1, 2, 10, 34, 47, 128, 255, 256, 512], float4_e2m1fn: [0, 1, 2, 3, 4, 6], float6_e2m3fn: [0, 1, 2, 3, 4, 5, 6, 7], float6_e3m2fn: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28], float8_e3m4: list( itertools.chain.from_iterable( range(1 << n, 2 << n, 1 << max(0, n - 3)) for n in range(4) ) ), float8_e4m3: list( itertools.chain.from_iterable( range(1 << n, 2 << n, 1 << max(0, n - 3)) for n in range(8) ) ), float8_e4m3b11fnuz: [*range(16), *range(16, 30, 2)], float8_e4m3fn: list( itertools.chain.from_iterable( range(1 << n, 2 << n, 1 << max(0, n - 3)) for n in range(9) ) )[:-1], float8_e4m3fnuz: list( itertools.chain.from_iterable( range(1 << n, 2 << n, 1 << max(0, n - 3)) for n in range(8) ) )[:-1], float8_e5m2: list( itertools.chain.from_iterable( range(1 << n, 2 << n, 1 << max(0, n - 2)) for n in range(16) ) ), float8_e5m2fnuz: list( itertools.chain.from_iterable( range(1 << n, 2 << n, 1 << max(0, n - 2)) for n in range(16) ) ), float8_e8m0fnu: [1, 2, 256], } 
# pylint: disable=g-complex-comprehension @multi_threaded( num_workers=3, skip_tests=[ "testDiv", "testPickleable", "testRoundTripNumpyTypes", "testRoundTripToNumpy", "testConstructFromDtype", "testHashNumbers", "testHashNan", ], ) @parameterized.named_parameters( ( {"testcase_name": "_" + dtype.__name__, "float_type": dtype} for dtype in FLOAT_DTYPES ) ) class CustomFloatTest(parameterized.TestCase): """Tests the non-numpy Python methods of the custom float type.""" def testModuleName(self, float_type): self.assertEqual(float_type.__module__, "ml_dtypes") @ignore_warning(category=RuntimeWarning, message="invalid value encountered") def testPickleable(self, float_type): # https://github.com/jax-ml/jax/discussions/8505 x = np.arange(10, dtype=float_type) serialized = pickle.dumps(x) x_out = pickle.loads(serialized) self.assertEqual(x_out.dtype, x.dtype) np.testing.assert_array_equal(x_out.astype("float32"), x.astype("float32")) def testRoundTripToFloat(self, float_type): for v in FLOAT_VALUES[float_type]: np.testing.assert_equal(v, float(float_type(v))) @ignore_warning(category=RuntimeWarning, message="overflow encountered") def testRoundTripNumpyTypes(self, float_type): for dtype in [np.float16, np.float32, np.float64, np.longdouble]: for f in FLOAT_VALUES[float_type]: # Ignore values converting to NaN/Inf if np.abs(f) > np.finfo(dtype).max: continue np.testing.assert_equal(dtype(f), dtype(float_type(dtype(f)))) np.testing.assert_equal(float(dtype(f)), float(float_type(dtype(f)))) np.testing.assert_equal(dtype(f), dtype(float_type(np.array(f, dtype)))) np.testing.assert_equal( dtype(np.array(FLOAT_VALUES[float_type], float_type)), np.array(FLOAT_VALUES[float_type], dtype), ) def testRoundTripToInt(self, float_type): for v in INT_VALUES[float_type]: self.assertEqual(v, int(float_type(v))) if dtype_is_signed(float_type): self.assertEqual(-v, int(float_type(-v))) @ignore_warning(category=RuntimeWarning, message="overflow encountered") def testRoundTripToNumpy(self, 
float_type): for dtype in [ float_type, np.float16, np.float32, np.float64, np.longdouble, ]: with self.subTest(dtype.__name__): for v in FLOAT_VALUES[float_type]: if np.abs(v) > ml_dtypes.finfo(dtype).max: continue np.testing.assert_equal(dtype(v), dtype(float_type(dtype(v)))) np.testing.assert_equal(dtype(v), dtype(float_type(dtype(v)))) np.testing.assert_equal( dtype(v), dtype(float_type(np.array(v, dtype))) ) if ( dtype != float_type and ml_dtypes.finfo(float_type).max <= ml_dtypes.finfo(dtype).max ): np.testing.assert_equal( np.array(FLOAT_VALUES[float_type], dtype), float_type(np.array(FLOAT_VALUES[float_type], dtype)).astype( dtype ), ) def testCastBetweenCustomTypes(self, float_type): for dtype in FLOAT_DTYPES: # float8_e8m0 only registering cast <=> bfloat16 if ( float_type == float8_e8m0fnu or dtype == float8_e8m0fnu ) and dtype != bfloat16: continue x = np.array(FLOAT_VALUES[float_type], dtype=dtype) y = x.astype(float_type) z = x.astype(float).astype(float_type) numpy_assert_allclose(y, z, float_type=float_type) def testStr(self, float_type): for value in FLOAT_VALUES[float_type]: self.assertEqual( "%.6g" % float(float_type(value)), str(float_type(value)) ) def testFromStr(self, float_type): self.assertEqual(float_type(1.2), float_type("1.2")) if dtype_has_nan(float_type): self.assertTrue(np.isnan(float_type("nan"))) self.assertTrue(np.isnan(float_type("-nan"))) if dtype_has_inf(float_type): self.assertEqual(float_type(float("inf")), float_type("inf")) self.assertEqual(float_type(float("-inf")), float_type("-inf")) def testRepr(self, float_type): for value in FLOAT_VALUES[float_type]: self.assertEqual( "%.6g" % float(float_type(value)), repr(float_type(value)) ) def testItem(self, float_type): self.assertIsInstance(float_type(0).item(), float) def testHashZero(self, float_type): """Tests that negative zero and zero hash to the same value.""" if float_type == float8_e8m0fnu: raise self.skipTest("Skip hash zero test for E8M0 datatype.") 
self.assertEqual(hash(float_type(-0.0)), hash(float_type(0.0))) def testHashNumbers(self, float_type): for value in np.extract( np.isfinite(FLOAT_VALUES[float_type]), FLOAT_VALUES[float_type] ): with self.subTest(value): self.assertEqual(hash(value), hash(float_type(value)), str(value)) def testHashNan(self, float_type): for name, nan in [ ("PositiveNan", float_type(float("nan"))), ("NegativeNan", float_type(float("-nan"))), ]: with self.subTest(name): nan_hash = hash(nan) nan_object_hash = object.__hash__(nan) # The hash of a NaN is either 0 or a hash of the object pointer. self.assertIn(nan_hash, (sys.hash_info.nan, nan_object_hash), str(nan)) def testHashInf(self, float_type): if dtype_has_inf(float_type): self.assertEqual(sys.hash_info.inf, hash(float_type(float("inf"))), "inf") self.assertEqual( -sys.hash_info.inf, hash(float_type(float("-inf"))), "-inf" ) # Tests for Python operations def testNegate(self, float_type): for v in FLOAT_VALUES[float_type]: np.testing.assert_equal( float(float_type(-float(float_type(v)))), float(-float_type(v)) ) def testAdd(self, float_type): for a, b in [ (0, 0), (1, 0), (1, -1), (2, 3.5), (3.5, -2.25), (float("inf"), -2.25), (float("-inf"), -2.25), (3.5, float("nan")), ]: binary_operation_test(a, b, op=lambda a, b: a + b, float_type=float_type) def testAddScalarTypePromotion(self, float_type): """Tests type promotion against Numpy scalar values.""" types = [float_type, np.float16, np.float32, np.float64, np.longdouble] for lhs_type in types: for rhs_type in types: expected_type = numpy_promote_types( lhs_type, rhs_type, float_type=float_type, next_largest_fp_type=np.float32, ) actual_type = type(lhs_type(3.5) + rhs_type(2.25)) self.assertEqual(expected_type, actual_type) def testAddArrayTypePromotion(self, float_type): self.assertEqual( np.float32, type(float_type(3.5) + np.array(2.25, np.float32)) ) self.assertEqual( np.float32, type(np.array(3.5, np.float32) + float_type(2.25)) ) def testSub(self, float_type): for a, b in [ 
(0, 0), (1, 0), (1, -1), (2, 3.5), (3.5, -2.25), (-2.25, float("inf")), (-2.25, float("-inf")), (3.5, float("nan")), ]: binary_operation_test(a, b, op=lambda a, b: a - b, float_type=float_type) def testMul(self, float_type): for a, b in [ (0, 0), (1, 0), (1, -1), (3.5, -2.25), (float("inf"), -2.25), (float("-inf"), -2.25), (3.5, float("nan")), ]: binary_operation_test(a, b, op=lambda a, b: a * b, float_type=float_type) @ignore_warning(category=RuntimeWarning, message="invalid value encountered") @ignore_warning(category=RuntimeWarning, message="divide by zero encountered") def testDiv(self, float_type): for a, b in [ (0, 0), (1, 0), (1, -1), (2, 3.5), (3.5, -2.25), (float("inf"), -2.25), (float("-inf"), -2.25), (3.5, float("nan")), ]: binary_operation_test(a, b, op=lambda a, b: a / b, float_type=float_type) def testLess(self, float_type): for v in FLOAT_VALUES[float_type]: for w in FLOAT_VALUES[float_type]: result = float_type(v) < float_type(w) self.assertEqual(v < w, result) self.assertIsInstance(result, np.bool_) def testLessEqual(self, float_type): for v in FLOAT_VALUES[float_type]: for w in FLOAT_VALUES[float_type]: result = float_type(v) <= float_type(w) self.assertEqual(v <= w, result) self.assertIsInstance(result, np.bool_) def testGreater(self, float_type): for v in FLOAT_VALUES[float_type]: for w in FLOAT_VALUES[float_type]: result = float_type(v) > float_type(w) self.assertEqual(v > w, result) self.assertIsInstance(result, np.bool_) def testGreaterEqual(self, float_type): for v in FLOAT_VALUES[float_type]: for w in FLOAT_VALUES[float_type]: result = float_type(v) >= float_type(w) self.assertEqual(v >= w, result) self.assertIsInstance(result, np.bool_) def testEqual(self, float_type): for v in FLOAT_VALUES[float_type]: for w in FLOAT_VALUES[float_type]: result = float_type(v) == float_type(w) self.assertEqual(v == w, result) self.assertIsInstance(result, np.bool_) def testNotEqual(self, float_type): for v in FLOAT_VALUES[float_type]: for w in 
FLOAT_VALUES[float_type]: result = float_type(v) != float_type(w) self.assertEqual(v != w, result) self.assertIsInstance(result, np.bool_) def testNan(self, float_type): if not dtype_has_nan(float_type): self.skipTest("no NaN encoding") a = np.isnan(float_type(float("nan"))) self.assertTrue(a) numpy_assert_allclose( np.array([1.0, a]), np.array([1.0, a]), float_type=float_type ) a = np.array( [float_type(1.34375), float_type(1.4375), float_type(float("nan"))], dtype=float_type, ) b = np.array( [float_type(1.3359375), float_type(1.4375), float_type(float("nan"))], dtype=float_type, ) numpy_assert_allclose( a, b, rtol=0.1, atol=0.1, equal_nan=True, err_msg="", verbose=True, float_type=float_type, ) def testSort(self, float_type): # Note: np.sort doesn't work properly with NaNs since they always compare # False. values_to_sort = np.float32( [x for x in FLOAT_VALUES[float_type] if not np.isnan(x)] ) sorted_f32 = np.sort(values_to_sort) sorted_float_type = np.sort(values_to_sort.astype(float_type)) # pylint: disable=too-many-function-args np.testing.assert_equal(sorted_f32, np.float32(sorted_float_type)) def testArgmax(self, float_type): values_to_sort = np.float32( float_type(np.float32(FLOAT_VALUES[float_type])) ) argmax_f32 = np.argmax(values_to_sort) argmax_float_type = np.argmax(values_to_sort.astype(float_type)) # pylint: disable=too-many-function-args np.testing.assert_equal(argmax_f32, argmax_float_type) def testArgmaxOnNan(self, float_type): """Ensures we return the right thing for multiple NaNs.""" if not dtype_has_nan(float_type): self.skipTest("no NaN encoding") one_with_nans = np.array( [1.0, float("nan"), float("nan")], dtype=np.float32 ) np.testing.assert_equal( np.argmax(one_with_nans.astype(float_type)), np.argmax(one_with_nans) ) def testArgmaxOnNegativeInfinity(self, float_type): """Ensures we return the right thing for negative infinities.""" inf = np.array([float("-inf")], dtype=np.float32) np.testing.assert_equal(np.argmax(inf.astype(float_type)), 
np.argmax(inf)) def testArgmin(self, float_type): values_to_sort = np.float32( float_type(np.float32(FLOAT_VALUES[float_type])) ) argmin_f32 = np.argmin(values_to_sort) argmin_float_type = np.argmin(values_to_sort.astype(float_type)) # pylint: disable=too-many-function-args np.testing.assert_equal(argmin_f32, argmin_float_type) def testArgminOnNan(self, float_type): """Ensures we return the right thing for multiple NaNs.""" one_with_nans = np.array( [1.0, float("nan"), float("nan")], dtype=np.float32 ) np.testing.assert_equal( np.argmin(one_with_nans.astype(float_type)), np.argmin(one_with_nans) ) def testArgminOnPositiveInfinity(self, float_type): """Ensures we return the right thing for positive infinities.""" inf = np.array([float("inf")], dtype=np.float32) np.testing.assert_equal(np.argmin(inf.astype(float_type)), np.argmin(inf)) def testDtypeFromString(self, float_type): assert np.dtype(float_type.__name__) == np.dtype(float_type) def testIssubdtype(self, float_type): # In the future, we may want to make these more specific (e.g. use # np.number or np.floating instead of np.generic) by changing the # base in RegisterFloatDtype. 
self.assertTrue(np.issubdtype(float_type, np.generic)) self.assertTrue(np.issubdtype(np.dtype(float_type), np.generic)) def testCastToDtype(self, float_type): name = float_type.__name__ dt = np.dtype(float_type) self.assertIs(dt.type, float_type) self.assertEqual(dt.name, name) self.assertEqual(repr(dt), f"dtype({name})") def testConstructFromDtype(self, float_type): for np_dtype in NUMPY_DTYPES: with self.subTest(np_dtype.__name__): expected = float_type(1) actual = float_type(np_dtype(1)) self.assertEqual(type(expected), type(actual)) self.assertEqual(float(expected), float(actual)) def testByteSwap(self, float_type): """Test that byteswap works correctly.""" arr = np.array([1.0, 2.0, 3.0], dtype=float_type) original_bytes = arr.tobytes() # Test copy byteswap swapped = arr.byteswap(inplace=False) self.assertIsNot(swapped, arr) # Different object self.assertEqual(arr.tobytes(), original_bytes) # Original unchanged if np.dtype(float_type).itemsize == 2: # 16-bit types should swap bytes self.assertNotEqual(original_bytes, swapped.tobytes()) # Test in-place byteswap arr_copy = arr.copy() result = arr_copy.byteswap(inplace=True) self.assertIs(result, arr_copy) # Same object self.assertEqual(arr_copy.tobytes(), swapped.tobytes()) # Same bytes # Double swap restores original arr_copy.byteswap(inplace=True) self.assertEqual(arr_copy.tobytes(), original_bytes) else: # 8-bit types should be unchanged self.assertEqual(original_bytes, swapped.tobytes()) BinaryOp = collections.namedtuple("BinaryOp", ["op"]) UNARY_UFUNCS = [ np.negative, np.positive, np.absolute, np.fabs, np.rint, np.sign, np.conjugate, np.exp, np.exp2, np.expm1, np.log, np.log10, np.log1p, np.log2, np.sqrt, np.square, np.cbrt, np.reciprocal, np.sin, np.cos, np.tan, np.arcsin, np.arccos, np.arctan, np.sinh, np.cosh, np.tanh, np.arcsinh, np.arccosh, np.arctanh, np.deg2rad, np.rad2deg, np.floor, np.ceil, np.trunc, ] BINARY_UFUNCS = [ np.add, np.subtract, np.multiply, np.divide, np.logaddexp, np.logaddexp2, 
np.floor_divide, np.power, np.remainder, np.fmod, np.heaviside, np.arctan2, np.hypot, np.maximum, np.minimum, np.fmax, np.fmin, np.copysign, ] BINARY_PREDICATE_UFUNCS = [ np.equal, np.not_equal, np.less, np.greater, np.less_equal, np.greater_equal, np.logical_and, np.logical_or, np.logical_xor, ] # pylint: disable=g-complex-comprehension @multi_threaded( num_workers=3, skip_tests=[ "testBinaryPredicateUfunc", "testBinaryUfunc", "testCanCast", "testCasts", "testConformNumpyComplex", "testCopySign", # pytest 9.0.1's subtest appears not to be thread-safe "testDivmod", "testDivmodCornerCases", "testFloordivCornerCases", "testFrexp", "testLdexp", "testModf", "testPredicateUfunc", "testSpacing", "testUnaryUfunc", ], ) @parameterized.named_parameters( ( {"testcase_name": "_" + dtype.__name__, "float_type": dtype} for dtype in FLOAT_DTYPES ) ) class CustomFloatNumPyTest(parameterized.TestCase): """Tests NumPy integration of the custom float types.""" def testDtype(self, float_type): self.assertEqual(float_type, np.dtype(float_type)) def testHash(self, float_type): h = hash(np.dtype(float_type)) self.assertEqual(h, hash(np.dtype(float_type.dtype))) self.assertEqual(h, hash(np.dtype(float_type.__name__))) def testDeepCopyDoesNotAlterHash(self, float_type): # For context, see https://github.com/jax-ml/jax/issues/4651. If the hash # value of the type descriptor is not initialized correctly, a deep copy # can change the type hash. 
dtype = np.dtype(float_type) h = hash(dtype) _ = copy.deepcopy(dtype) self.assertEqual(h, hash(dtype)) def testArray(self, float_type): x = np.array([[1, 2, 4]], dtype=float_type) self.assertEqual(float_type, x.dtype) self.assertEqual("[[1 2 4]]", str(x)) np.testing.assert_equal(x, x) numpy_assert_allclose(x, x, float_type=float_type) self.assertTrue((x == x).all()) def testComparisons(self, float_type): x0, x1, y0 = 6, 1, 3 x = np.array([x0, x1, -x0], dtype=np.float32) y = np.array([y0, x1, 0], dtype=np.float32) if float_type == float8_e8m0fnu: x = np.array([30, 7, 1], dtype=np.float32) y = np.array([17, 7, 0.125], dtype=np.float32) bx = x.astype(float_type) by = y.astype(float_type) np.testing.assert_equal(x == y, bx == by) np.testing.assert_equal(x != y, bx != by) np.testing.assert_equal(x < y, bx < by) np.testing.assert_equal(x > y, bx > by) np.testing.assert_equal(x <= y, bx <= by) np.testing.assert_equal(x >= y, bx >= by) def testEqual2(self, float_type): a = np.array([7], float_type) b = np.array([3], float_type) self.assertFalse(a.__eq__(b)) def testCanCast(self, float_type): allowed_casts = [ (np.bool_, float_type), (np.int8, float_type), (np.uint8, float_type), (float_type, np.float32), (float_type, np.float64), (float_type, np.longdouble), (float_type, np.complex64), (float_type, np.complex128), (float_type, np.clongdouble), ] all_dtypes = [ np.float16, np.float32, np.float64, np.longdouble, np.int8, np.int16, np.int32, np.int64, np.complex64, np.complex128, np.clongdouble, np.uint8, np.uint16, np.uint32, np.uint64, np.intc, np.int_, np.longlong, np.uintc, np.ulonglong, ] for d in all_dtypes: with self.subTest(d.__name__): self.assertEqual( (float_type, d) in allowed_casts, np.can_cast(float_type, d) ) self.assertEqual( (d, float_type) in allowed_casts, np.can_cast(d, float_type) ) @ignore_warning( category=RuntimeWarning, message="invalid value encountered in cast" ) def testCasts(self, float_type): for dtype in [ np.float16, np.float32, np.float64, 
np.longdouble, np.int8, np.int16, np.int32, np.int64, np.complex64, np.complex128, np.clongdouble, np.uint8, np.uint16, np.uint32, np.uint64, np.intc, np.int_, np.longlong, np.uintc, np.ulonglong, ]: x = np.array([[1, 2, 4]], dtype=dtype) y = x.astype(float_type) z = y.astype(dtype) self.assertTrue(np.all(x == y)) self.assertEqual(float_type, y.dtype) self.assertTrue(np.all(x == z)) self.assertEqual(dtype, z.dtype) @ignore_warning(category=ComplexWarning) def testConformNumpyComplex(self, float_type): for dtype in [np.complex64, np.complex128, np.clongdouble]: x = np.array([0.5, 1.0 + 2.0j, 4.0], dtype=dtype) y_np = x.astype(np.float32) y_tf = x.astype(float_type) numpy_assert_allclose(y_np, y_tf, atol=2e-2, float_type=float_type) z_np = y_np.astype(dtype) z_tf = y_tf.astype(dtype) numpy_assert_allclose(z_np, z_tf, atol=2e-2, float_type=float_type) def testArange(self, float_type): np.testing.assert_equal( np.arange(1, 100, dtype=np.float32).astype(float_type), np.arange(1, 100, dtype=float_type), ) if float_type == float8_e8m0fnu: raise self.skipTest("Skip negative ranges for E8M0.") np.testing.assert_equal( np.arange(-6, 6, 2, dtype=np.float32).astype(float_type), np.arange(-6, 6, 2, dtype=float_type), ) np.testing.assert_equal( np.arange(-0.0, -2.0, -0.5, dtype=np.float32).astype(float_type), np.arange(-0.0, -2.0, -0.5, dtype=float_type), ) @ignore_warning(category=RuntimeWarning, message="invalid value encountered") @ignore_warning(category=RuntimeWarning, message="divide by zero encountered") def testUnaryUfunc(self, float_type): for op in UNARY_UFUNCS: with self.subTest(op.__name__): rng = np.random.RandomState(seed=42) x = rng.randn(3, 7, 10).astype(float_type) numpy_assert_allclose( op(x).astype(np.float32), truncate(op(x.astype(np.float32)), float_type=float_type), rtol=1e-4, float_type=float_type, ) @ignore_warning(category=RuntimeWarning, message="invalid value encountered") @ignore_warning(category=RuntimeWarning, message="divide by zero encountered") 
def testBinaryUfunc(self, float_type): for op in BINARY_UFUNCS: with self.subTest(op.__name__): rng = np.random.RandomState(seed=42) x = rng.randn(3, 7, 10).astype(float_type) y = rng.randn(4, 1, 7, 10).astype(float_type) numpy_assert_allclose( op(x, y).astype(np.float32), truncate( op(x.astype(np.float32), y.astype(np.float32)), float_type=float_type, ), rtol=1e-4, float_type=float_type, ) @ignore_warning(category=RuntimeWarning, message="invalid value encountered") def testBinaryPredicateUfunc(self, float_type): for op in BINARY_PREDICATE_UFUNCS: with self.subTest(op.__name__): rng = np.random.RandomState(seed=42) x = rng.randn(3, 7).astype(float_type) y = rng.randn(4, 1, 7).astype(float_type) np.testing.assert_equal( op(x, y), op(x.astype(np.float32), y.astype(np.float32)) ) @ignore_warning(category=RuntimeWarning, message="invalid value encountered") def testPredicateUfunc(self, float_type): for op in [np.isfinite, np.isinf, np.isnan, np.signbit, np.logical_not]: with self.subTest(op.__name__): rng = np.random.RandomState(seed=42) shape = (3, 7, 10) posinf_flips = rng.rand(*shape) < 0.1 neginf_flips = rng.rand(*shape) < 0.1 nan_flips = rng.rand(*shape) < 0.1 vals = rng.randn(*shape) vals = np.where(posinf_flips, np.inf, vals) vals = np.where(neginf_flips, -np.inf, vals) vals = np.where(nan_flips, np.nan, vals) vals = vals.astype(float_type) np.testing.assert_equal(op(vals), op(vals.astype(np.float32))) @ignore_warning(category=RuntimeWarning, message="invalid value encountered") def testDivmod(self, float_type): rng = np.random.RandomState(seed=42) x = rng.randn(3, 7).astype(float_type) y = rng.randn(4, 1, 7).astype(float_type) x = np.where(np.isfinite(x), x, float_type(1)) y = np.where(np.isfinite(y), y, float_type(1)) y = np.where(y == 0, float_type(1), y) o1, o2 = np.divmod(x, y) e1, e2 = np.divmod(x.astype(np.float32), y.astype(np.float32)) numpy_assert_allclose( o1, truncate(e1, float_type=float_type), rtol=1e-2, float_type=float_type, ) 
numpy_assert_allclose( o2, truncate(e2, float_type=float_type), rtol=1e-2, float_type=float_type, ) @ignore_warning(category=RuntimeWarning, message="invalid value encountered") @ignore_warning(category=RuntimeWarning, message="divide by zero encountered") def testDivmodCornerCases(self, float_type): x = np.array( [-np.nan, -np.inf, -1.0, -0.0, 0.0, 1.0, np.inf, np.nan], dtype=float_type, ) xf32 = x.astype("float32") out = np.divmod.outer(x, x) expected = np.divmod.outer(xf32, xf32) numpy_assert_allclose( out[0], truncate(expected[0], float_type=float_type), rtol=0.0, float_type=float_type, ) numpy_assert_allclose( out[1], truncate(expected[1], float_type=float_type), rtol=0.0, float_type=float_type, ) @ignore_warning(category=RuntimeWarning, message="invalid value encountered") @ignore_warning(category=RuntimeWarning, message="divide by zero encountered") def testFloordivCornerCases(self, float_type): # Regression test for https://github.com/jax-ml/ml_dtypes/issues/170 x = np.array( [-np.nan, -np.inf, -1.0, -0.0, 0.0, 1.0, np.inf, np.nan], dtype=float_type, ) xf32 = x.astype("float32") out = np.floor_divide.outer(x, x) expected = np.floor_divide.outer(xf32, xf32) numpy_assert_allclose( out, truncate(expected, float_type=float_type), rtol=0.0, float_type=float_type, ) @ignore_warning(category=RuntimeWarning, message="invalid value encountered") def testModf(self, float_type): rng = np.random.RandomState(seed=42) x = rng.randn(3, 7).astype(float_type) o1, o2 = np.modf(x) e1, e2 = np.modf(x.astype(np.float32)) numpy_assert_allclose( o1.astype(np.float32), truncate(e1, float_type=float_type), rtol=1e-2, float_type=float_type, ) numpy_assert_allclose( o2.astype(np.float32), truncate(e2, float_type=float_type), rtol=1e-2, float_type=float_type, ) @ignore_warning(category=RuntimeWarning, message="invalid value encountered") def testLdexp(self, float_type): rng = np.random.RandomState(seed=42) x = rng.randn(3, 7).astype(float_type) y = rng.randint(-50, 50, (1, 
7)).astype(np.int32) self.assertEqual(np.ldexp(x, y).dtype, x.dtype) numpy_assert_allclose( np.ldexp(x, y).astype(np.float32), truncate(np.ldexp(x.astype(np.float32), y), float_type=float_type), rtol=1e-2, atol=1e-6, float_type=float_type, ) @ignore_warning(category=RuntimeWarning, message="invalid value encountered") def testFrexp(self, float_type): rng = np.random.RandomState(seed=42) x = rng.randn(3, 7).astype(float_type) x = np.where(np.isfinite(x), x, float_type(1)) mant1, exp1 = np.frexp(x) mant2, exp2 = np.frexp(x.astype(np.float32)) np.testing.assert_equal(exp1, exp2) kwargs = {"rtol": 0.01} if float_type == float6_e2m3fn: kwargs = {"rtol": 0.1} elif float_type == float4_e2m1fn: kwargs = {"atol": 0.25} numpy_assert_allclose(mant1, mant2, float_type=float_type, **kwargs) def testCopySign(self, float_type): if not dtype_is_signed(float_type): raise self.skipTest("Skip copy sign test for unsigned floating formats.") bits_type = np.uint16 if float_type == bfloat16 else np.uint8 bit_size = ml_dtypes.finfo(float_type).bits bit_sign = 1 << (bit_size - 1) for bits in range(1, min(bit_sign, 256)): with self.subTest(bits): val = bits_type(bits).view(float_type) val_with_sign = np.copysign(val, float_type(-1)) val_with_sign_bits = val_with_sign.view(bits_type) self.assertEqual(bits | bit_sign, val_with_sign_bits) def testNextAfter(self, float_type): one = np.array(1.0, dtype=float_type) two = np.array(2.0, dtype=float_type) zero = np.array(0.0, dtype=float_type) np.testing.assert_equal( np.nextafter(one, two) - one, ml_dtypes.finfo(float_type).eps ) np.testing.assert_equal( np.nextafter(one, zero) - one, -ml_dtypes.finfo(float_type).epsneg ) np.testing.assert_equal(np.nextafter(one, one), one) smallest_denormal = ml_dtypes.finfo(float_type).smallest_subnormal if dtype_is_signed(float_type): np.testing.assert_equal(np.nextafter(zero, one), smallest_denormal) np.testing.assert_equal(np.nextafter(zero, -one), -smallest_denormal) if dtype_has_nan(float_type): nan = 
np.array(np.nan, dtype=float_type) np.testing.assert_equal(np.isnan(np.nextafter(nan, one)), True) np.testing.assert_equal(np.isnan(np.nextafter(one, nan)), True) for a, b in itertools.permutations([0.0, nan], 2): np.testing.assert_equal( np.nextafter( np.array(a, dtype=np.float32), np.array(b, dtype=np.float32) ), np.nextafter( np.array(a, dtype=float_type), np.array(b, dtype=float_type) ), ) @ignore_warning(category=RuntimeWarning, message="invalid value encountered") def testSpacing(self, float_type): # Sweep a variety of binades to see that spacing gives the proper ULP. with self.subTest(name="Subnormals"): for i in range( int(np.log2(float(ml_dtypes.finfo(float_type).smallest_subnormal))), int(np.log2(float(ml_dtypes.finfo(float_type).smallest_normal))), ): power_of_two = float_type(2.0**i) distance = ml_dtypes.finfo(float_type).smallest_subnormal np.testing.assert_equal(np.spacing(power_of_two), distance) np.testing.assert_equal(np.spacing(-power_of_two), -distance) # Normals have a distance which depends on their binade. with self.subTest(name="Normals"): for i in range( int(np.log2(float(ml_dtypes.finfo(float_type).smallest_normal))), int(np.log2(float(ml_dtypes.finfo(float_type).max))), ): power_of_two = float_type(2.0**i) distance = ml_dtypes.finfo(float_type).eps * power_of_two np.testing.assert_equal(np.spacing(power_of_two), distance) if dtype_is_signed(float_type): np.testing.assert_equal(np.spacing(-power_of_two), -distance) # Check that spacing agrees with arithmetic involving nextafter. 
with self.subTest(name="NextAfter"): for x in FLOAT_VALUES[float_type]: x_float_type = float_type(x) spacing = np.spacing(x_float_type) toward = np.copysign(float_type(2.0 * np.abs(x) + 1), x_float_type) nextup = np.nextafter(x_float_type, toward) if np.isnan(spacing): self.assertTrue(np.isnan(nextup - x_float_type)) elif spacing: np.testing.assert_equal(spacing, nextup - x_float_type) else: # If type has no NaN or infinity, spacing of the maximum value is # expected to be zero (next value does not exist). self.assertFalse(dtype_has_nan(float_type)) self.assertEqual(abs(x_float_type), ml_dtypes.finfo(float_type).max) # Check that spacing for special values gives the correct answer. with self.subTest(name="NonFinite"): if dtype_has_nan(float_type): nan = float_type(float("nan")) np.testing.assert_equal(np.spacing(nan), np.spacing(np.float32(nan))) if dtype_has_inf(float_type): inf = float_type(float("inf")) np.testing.assert_equal(np.spacing(inf), np.spacing(np.float32(inf))) if __name__ == "__main__": absltest.main() jax-ml-ml_dtypes-882eb0f/ml_dtypes/tests/finfo_test.py000066400000000000000000000110751510671665600232150ustar00rootroot00000000000000# Copyright 2022 The ml_dtypes Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from absl.testing import absltest from absl.testing import parameterized import ml_dtypes from multi_thread_utils import multi_threaded import numpy as np ALL_DTYPES = [ ml_dtypes.bfloat16, ml_dtypes.float4_e2m1fn, ml_dtypes.float6_e2m3fn, ml_dtypes.float6_e3m2fn, ml_dtypes.float8_e3m4, ml_dtypes.float8_e4m3, ml_dtypes.float8_e4m3b11fnuz, ml_dtypes.float8_e4m3fn, ml_dtypes.float8_e4m3fnuz, ml_dtypes.float8_e5m2, ml_dtypes.float8_e5m2fnuz, ml_dtypes.float8_e8m0fnu, ] DTYPES_WITH_NO_INFINITY = [ ml_dtypes.float8_e4m3b11fnuz, ml_dtypes.float8_e4m3fn, ml_dtypes.float8_e4m3fnuz, ml_dtypes.float8_e5m2fnuz, ml_dtypes.float8_e8m0fnu, ] DTYPES_WITH_NO_INFINITY_AND_NO_NAN = [ ml_dtypes.float4_e2m1fn, ml_dtypes.float6_e2m3fn, ml_dtypes.float6_e3m2fn, ] UINT_TYPES = { 4: np.uint8, 6: np.uint8, 8: np.uint8, 16: np.uint16, } @multi_threaded(num_workers=3) class FinfoTest(parameterized.TestCase): def assertNanEqual(self, x, y): if np.isnan(x) and np.isnan(y): return self.assertEqual(x, y) @parameterized.named_parameters( {"testcase_name": f"_{dtype.__name__}", "dtype": np.dtype(dtype)} for dtype in ALL_DTYPES ) def testFInfo(self, dtype): info = ml_dtypes.finfo(dtype) assert ml_dtypes.finfo(dtype.name) is info assert ml_dtypes.finfo(dtype.type) is info _ = str(info) # doesn't crash def make_val(val): return np.array(val, dtype=dtype) def assert_representable(val): self.assertEqual(make_val(val).item(), val) def assert_infinite(val): val = make_val(val) if dtype in DTYPES_WITH_NO_INFINITY_AND_NO_NAN: self.assertEqual(val, info.max) elif dtype in DTYPES_WITH_NO_INFINITY: self.assertTrue(np.isnan(val), f"expected NaN, got {val}") else: self.assertTrue(np.isposinf(val), f"expected inf, got {val}") def assert_zero(val): self.assertEqual(make_val(val), make_val(0)) self.assertEqual(np.array(0, dtype).dtype, dtype) self.assertIs(info.dtype, dtype) if info.bits >= 8: self.assertEqual(info.bits, np.array(0, dtype).itemsize * 8) # Unsigned float => no sign bit. 
if info.min >= 0.0: self.assertEqual(info.nmant + info.nexp, info.bits) else: self.assertEqual(info.nmant + info.nexp + 1, info.bits) assert_representable(info.tiny) assert_representable(info.max) assert_representable(info.min) if dtype not in DTYPES_WITH_NO_INFINITY_AND_NO_NAN: assert_infinite(np.spacing(info.max)) assert info.max > 0.0 if info.min < 0.0 and dtype not in DTYPES_WITH_NO_INFINITY_AND_NO_NAN: # Only valid for signed floating format. assert_infinite(-np.spacing(info.min)) elif info.min > 0.0: # No zero in floating point format. assert_infinite(0) assert_infinite(make_val(-1)) elif info.min == 0.0: # Zero supported, but not negative values. self.assertEqual(make_val(0), 0) assert_infinite(make_val(-1)) assert_representable(2.0 ** (info.maxexp - 1)) assert_infinite(2.0**info.maxexp) assert_representable(info.smallest_subnormal) if info.min < 0.0: assert_zero(info.smallest_subnormal * 0.5) self.assertGreater(info.smallest_normal, 0.0) self.assertEqual(info.tiny, info.smallest_normal) # Identities according to the documentation: np.testing.assert_allclose(info.resolution, make_val(10**-info.precision)) self.assertEqual(info.epsneg, make_val(2**info.negep)) self.assertEqual(info.eps, make_val(2**info.machep)) self.assertEqual(info.iexp, info.nexp) is_min_exponent_valid_normal = ( make_val(2**info.minexp) == info.smallest_normal ) # Check that minexp is consistent with nmant (subnormal representation) if not is_min_exponent_valid_normal and info.nmant > 0: self.assertEqual( make_val(2**info.minexp).view(UINT_TYPES[info.bits]), 2**info.nmant, ) if __name__ == "__main__": absltest.main() jax-ml-ml_dtypes-882eb0f/ml_dtypes/tests/float8_test.cc000066400000000000000000001367111510671665600232530ustar00rootroot00000000000000/* Copyright 2022 The ml_dtypes Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "ml_dtypes/include/float8.h" #include #include #include #include #include #include #include #include "unsupported/Eigen/CXX11/Tensor" namespace ml_dtypes { namespace { template class Float8Test : public ::testing::Test {}; // Helper utility for prettier test names. struct Float8TestParamNames { template static std::string GetName(int idx) { if constexpr (std::is_same_v) { return "float8_e4m3fn"; } else if constexpr (std::is_same_v) { return "float8_e4m3b11fnuz"; } else if constexpr (std::is_same_v) { return "float8_e3m4"; } else if constexpr (std::is_same_v) { return "float8_e4m3"; } else if constexpr (std::is_same_v) { return "float8_e5m2"; } else if constexpr (std::is_same_v) { return "float8_e4m3fnuz"; } else if constexpr (std::is_same_v) { return "float8_e5m2fnuz"; } else if constexpr (std::is_same_v) { return "float8_e8m0fnu"; } return ""; } }; using Float8Types = ::testing::Types; TYPED_TEST_SUITE(Float8Test, Float8Types, Float8TestParamNames); TEST(Float8E3m4Test, NumericLimits) { EXPECT_TRUE( Eigen::numext::isnan(std::numeric_limits::quiet_NaN())); EXPECT_TRUE( Eigen::numext::isnan(std::numeric_limits::signaling_NaN())); EXPECT_EQ(static_cast(std::numeric_limits::min()), 0.25); EXPECT_EQ(static_cast(std::numeric_limits::max()), 15.5); EXPECT_EQ(static_cast(std::numeric_limits::lowest()), -15.5); EXPECT_EQ(static_cast(std::numeric_limits::epsilon()), 0.0625); EXPECT_EQ(static_cast(std::numeric_limits::round_error()), 0.5); EXPECT_TRUE( 
Eigen::numext::isinf(std::numeric_limits::infinity())); EXPECT_EQ(static_cast(std::numeric_limits::denorm_min()), std::exp2(-6)); EXPECT_EQ(std::numeric_limits::digits, 5); EXPECT_EQ(std::numeric_limits::digits10, 1); EXPECT_EQ(std::numeric_limits::max_digits10, 3); EXPECT_EQ(std::numeric_limits::min_exponent, -1); EXPECT_EQ(std::numeric_limits::min_exponent10, 0); EXPECT_EQ(std::numeric_limits::max_exponent, 4); EXPECT_EQ(std::numeric_limits::max_exponent10, 1); EXPECT_EQ(std::numeric_limits::is_iec559, true); EXPECT_EQ(std::numeric_limits::has_infinity, true); EXPECT_EQ(std::numeric_limits::has_quiet_NaN, true); EXPECT_EQ(std::numeric_limits::has_signaling_NaN, true); } TEST(Float8E4m3Test, NumericLimits) { EXPECT_TRUE( Eigen::numext::isnan(std::numeric_limits::quiet_NaN())); EXPECT_TRUE( Eigen::numext::isnan(std::numeric_limits::signaling_NaN())); EXPECT_EQ(static_cast(std::numeric_limits::min()), std::exp2(-6)); EXPECT_EQ(static_cast(std::numeric_limits::max()), 240); EXPECT_EQ(static_cast(std::numeric_limits::lowest()), -240); EXPECT_EQ(static_cast(std::numeric_limits::epsilon()), 0.125); EXPECT_EQ(static_cast(std::numeric_limits::round_error()), 0.5); EXPECT_TRUE( Eigen::numext::isinf(std::numeric_limits::infinity())); EXPECT_EQ(static_cast(std::numeric_limits::denorm_min()), std::exp2(-9)); EXPECT_EQ(std::numeric_limits::digits, 4); EXPECT_EQ(std::numeric_limits::digits10, 0); EXPECT_EQ(std::numeric_limits::max_digits10, 3); EXPECT_EQ(std::numeric_limits::min_exponent, -5); EXPECT_EQ(std::numeric_limits::min_exponent10, -1); EXPECT_EQ(std::numeric_limits::max_exponent, 8); EXPECT_EQ(std::numeric_limits::max_exponent10, 2); EXPECT_EQ(std::numeric_limits::is_iec559, true); EXPECT_EQ(std::numeric_limits::has_infinity, true); EXPECT_EQ(std::numeric_limits::has_quiet_NaN, true); EXPECT_EQ(std::numeric_limits::has_signaling_NaN, true); } TEST(Float8E4m3fnTest, NumericLimits) { EXPECT_TRUE( Eigen::numext::isnan(std::numeric_limits::quiet_NaN())); 
EXPECT_TRUE(Eigen::numext::isnan( std::numeric_limits::signaling_NaN())); EXPECT_EQ(static_cast(std::numeric_limits::min()), std::exp2(-6)); EXPECT_EQ(static_cast(std::numeric_limits::max()), 448); EXPECT_EQ(static_cast(std::numeric_limits::lowest()), -448); EXPECT_EQ(static_cast(std::numeric_limits::epsilon()), 0.125); EXPECT_EQ( static_cast(std::numeric_limits::round_error()), 0.5); // No infinity, represent as NaN. EXPECT_TRUE( Eigen::numext::isnan(std::numeric_limits::infinity())); EXPECT_EQ( static_cast(std::numeric_limits::denorm_min()), std::exp2(-9)); EXPECT_EQ(std::numeric_limits::digits, 4); EXPECT_EQ(std::numeric_limits::digits10, 0); EXPECT_EQ(std::numeric_limits::max_digits10, 3); EXPECT_EQ(std::numeric_limits::min_exponent, -5); EXPECT_EQ(std::numeric_limits::min_exponent10, -1); EXPECT_EQ(std::numeric_limits::max_exponent, 9); EXPECT_EQ(std::numeric_limits::max_exponent10, 2); EXPECT_EQ(std::numeric_limits::is_iec559, false); EXPECT_EQ(std::numeric_limits::has_infinity, false); EXPECT_EQ(std::numeric_limits::has_quiet_NaN, true); EXPECT_EQ(std::numeric_limits::has_signaling_NaN, false); } TEST(Float8E4m3b11fnuzTest, NumericLimits) { EXPECT_TRUE(Eigen::numext::isnan( std::numeric_limits::quiet_NaN())); EXPECT_TRUE(Eigen::numext::isnan( std::numeric_limits::signaling_NaN())); EXPECT_EQ(static_cast(std::numeric_limits::min()), std::exp2(-10)); EXPECT_EQ(static_cast(std::numeric_limits::max()), 30); EXPECT_EQ( static_cast(std::numeric_limits::lowest()), -30); EXPECT_EQ( static_cast(std::numeric_limits::epsilon()), 0.125); EXPECT_EQ(static_cast( std::numeric_limits::round_error()), 0.5); // No infinity, represent as NaN. 
EXPECT_TRUE(Eigen::numext::isnan( std::numeric_limits::infinity())); EXPECT_EQ( static_cast(std::numeric_limits::denorm_min()), std::exp2(-13)); EXPECT_EQ(std::numeric_limits::digits, 4); EXPECT_EQ(std::numeric_limits::digits10, 0); EXPECT_EQ(std::numeric_limits::max_digits10, 3); EXPECT_EQ(std::numeric_limits::min_exponent, -9); EXPECT_EQ(std::numeric_limits::min_exponent10, -3); EXPECT_EQ(std::numeric_limits::max_exponent, 5); EXPECT_EQ(std::numeric_limits::max_exponent10, 1); EXPECT_EQ(std::numeric_limits::is_iec559, false); EXPECT_EQ(std::numeric_limits::has_infinity, false); EXPECT_EQ(std::numeric_limits::has_quiet_NaN, true); EXPECT_EQ(std::numeric_limits::has_signaling_NaN, false); } TEST(Float8E4m3fnuzTest, NumericLimits) { EXPECT_TRUE( Eigen::numext::isnan(std::numeric_limits::quiet_NaN())); EXPECT_TRUE(Eigen::numext::isnan( std::numeric_limits::signaling_NaN())); EXPECT_EQ(static_cast(std::numeric_limits::min()), std::exp2(-7)); EXPECT_EQ(static_cast(std::numeric_limits::max()), 240); EXPECT_EQ(static_cast(std::numeric_limits::lowest()), -240); EXPECT_EQ(static_cast(std::numeric_limits::epsilon()), 0.125); EXPECT_EQ( static_cast(std::numeric_limits::round_error()), 0.5); // No infinity, represent as NaN. 
EXPECT_TRUE( Eigen::numext::isnan(std::numeric_limits::infinity())); EXPECT_EQ( static_cast(std::numeric_limits::denorm_min()), std::exp2(-10)); EXPECT_EQ(std::numeric_limits::digits, 4); EXPECT_EQ(std::numeric_limits::digits10, 0); EXPECT_EQ(std::numeric_limits::max_digits10, 3); EXPECT_EQ(std::numeric_limits::min_exponent, -6); EXPECT_EQ(std::numeric_limits::min_exponent10, -2); EXPECT_EQ(std::numeric_limits::max_exponent, 8); EXPECT_EQ(std::numeric_limits::max_exponent10, 2); EXPECT_EQ(std::numeric_limits::is_iec559, false); EXPECT_EQ(std::numeric_limits::has_infinity, false); EXPECT_EQ(std::numeric_limits::has_quiet_NaN, true); EXPECT_EQ(std::numeric_limits::has_signaling_NaN, false); } TEST(Float8E5m2Test, NumericLimits) { EXPECT_TRUE( Eigen::numext::isnan(std::numeric_limits::quiet_NaN())); EXPECT_TRUE( Eigen::numext::isnan(std::numeric_limits::signaling_NaN())); EXPECT_EQ(static_cast(std::numeric_limits::min()), std::exp2(-14)); EXPECT_EQ(static_cast(std::numeric_limits::max()), 57344); EXPECT_EQ(static_cast(std::numeric_limits::lowest()), -57344); EXPECT_EQ(static_cast(std::numeric_limits::epsilon()), 0.25); EXPECT_EQ(static_cast(std::numeric_limits::round_error()), 0.5); EXPECT_TRUE( Eigen::numext::isinf(std::numeric_limits::infinity())); EXPECT_EQ(static_cast(std::numeric_limits::denorm_min()), std::exp2(-16)); EXPECT_EQ(std::numeric_limits::digits, 3); EXPECT_EQ(std::numeric_limits::digits10, 0); EXPECT_EQ(std::numeric_limits::max_digits10, 2); EXPECT_EQ(std::numeric_limits::min_exponent, -13); EXPECT_EQ(std::numeric_limits::min_exponent10, -4); EXPECT_EQ(std::numeric_limits::max_exponent, 16); EXPECT_EQ(std::numeric_limits::max_exponent10, 4); EXPECT_EQ(std::numeric_limits::is_iec559, true); EXPECT_EQ(std::numeric_limits::has_infinity, true); EXPECT_EQ(std::numeric_limits::has_quiet_NaN, true); EXPECT_EQ(std::numeric_limits::has_signaling_NaN, true); } TEST(Float8E5m2fnuzTest, NumericLimits) { EXPECT_TRUE( 
Eigen::numext::isnan(std::numeric_limits::quiet_NaN())); EXPECT_TRUE(Eigen::numext::isnan( std::numeric_limits::signaling_NaN())); EXPECT_EQ(static_cast(std::numeric_limits::min()), std::exp2(-15)); EXPECT_EQ(static_cast(std::numeric_limits::max()), 57344); EXPECT_EQ(static_cast(std::numeric_limits::lowest()), -57344); EXPECT_EQ(static_cast(std::numeric_limits::epsilon()), 0.25); EXPECT_EQ( static_cast(std::numeric_limits::round_error()), 0.5); // No infinity, represented as NaN. EXPECT_TRUE( Eigen::numext::isnan(std::numeric_limits::infinity())); EXPECT_EQ( static_cast(std::numeric_limits::denorm_min()), std::exp2(-17)); EXPECT_EQ(std::numeric_limits::digits, 3); EXPECT_EQ(std::numeric_limits::digits10, 0); EXPECT_EQ(std::numeric_limits::max_digits10, 2); EXPECT_EQ(std::numeric_limits::min_exponent, -14); EXPECT_EQ(std::numeric_limits::min_exponent10, -4); EXPECT_EQ(std::numeric_limits::max_exponent, 16); EXPECT_EQ(std::numeric_limits::max_exponent10, 4); EXPECT_EQ(std::numeric_limits::is_iec559, false); EXPECT_EQ(std::numeric_limits::has_infinity, false); EXPECT_EQ(std::numeric_limits::has_quiet_NaN, true); EXPECT_EQ(std::numeric_limits::has_signaling_NaN, false); } TEST(Float8E8m0fnuTest, NumericLimits) { using limits = std::numeric_limits; EXPECT_FALSE(limits::is_signed); EXPECT_TRUE(Eigen::numext::isnan(limits::quiet_NaN())); EXPECT_TRUE(Eigen::numext::isnan(limits::signaling_NaN())); EXPECT_TRUE(Eigen::numext::isnan(limits::infinity())); // No infinity. 
EXPECT_EQ(static_cast(limits::min()), 0x1p-127); EXPECT_EQ(static_cast(limits::max()), 0x1p+127); EXPECT_EQ(static_cast(limits::lowest()), 0x1p-127); EXPECT_EQ(static_cast(limits::epsilon()), 1.0); EXPECT_EQ(static_cast(limits::round_error()), 0.5); EXPECT_EQ(limits::digits, 1); EXPECT_EQ(limits::digits10, 0); EXPECT_EQ(limits::max_digits10, 2); EXPECT_EQ(limits::min_exponent, -126); EXPECT_EQ(limits::min_exponent10, -38); EXPECT_EQ(limits::max_exponent, 128); EXPECT_EQ(limits::max_exponent10, 38); EXPECT_EQ(limits::is_iec559, false); EXPECT_EQ(limits::has_infinity, false); EXPECT_EQ(limits::has_quiet_NaN, true); EXPECT_EQ(limits::has_signaling_NaN, false); } TYPED_TEST(Float8Test, FromRep) { using Float8 = TypeParam; Float8 x = Float8::FromRep(0x4F); EXPECT_EQ(x.rep(), 0x4F); } TYPED_TEST(Float8Test, Negate) { using Float8 = TypeParam; if (!std::numeric_limits::is_signed) { GTEST_SKIP() << "Type doesn't support negative numbers"; } Float8 x = -Float8::FromRep(0x4F); EXPECT_EQ(x.rep(), 0x80 | 0x4F); Float8 nan = -std::numeric_limits::quiet_NaN(); EXPECT_TRUE(Eigen::numext::isnan(nan)); } TYPED_TEST(Float8Test, BitCasts) { using Float8 = TypeParam; Float8 x = Float8::FromRep(0x47); EXPECT_EQ(Eigen::numext::bit_cast(x), 0x47); EXPECT_EQ(Eigen::numext::bit_cast(x.rep()).rep(), 0x47); } TYPED_TEST(Float8Test, UpCasts) { using Float8 = TypeParam; // Loop through each float8 value. for (int i = 0x00; i <= 0xFF; ++i) { // Cast up to each other floating-point type, and verify they are the same. Float8 f8 = Float8::FromRep(i); double f64 = static_cast(f8); float f32 = static_cast(f8); Eigen::bfloat16 bf16 = static_cast(f8); Eigen::half f16 = static_cast(f8); if (Eigen::numext::isnan(f8)) { EXPECT_TRUE(Eigen::numext::isnan(f64)); EXPECT_TRUE(Eigen::numext::isnan(f32)); EXPECT_TRUE(Eigen::numext::isnan(bf16)); EXPECT_TRUE(Eigen::numext::isnan(f16)); } else { EXPECT_EQ(f64, f32); EXPECT_EQ(f32, bf16); // E8M0 exponent range doesn't fit F16 type. 
if (!std::is_same_v) { EXPECT_EQ(bf16, f16); } } } } TYPED_TEST(Float8Test, DownCasts) { using Float8 = TypeParam; for (int i = 0x00; i <= 0xFF; ++i) { float x = static_cast(Float8::FromRep(i)); Float8 f64 = static_cast(static_cast(x)); Float8 f32 = static_cast(static_cast(x)); Float8 bf16 = static_cast(static_cast(x)); Float8 f16 = static_cast(static_cast(x)); if (Eigen::numext::isnan(x)) { EXPECT_TRUE(Eigen::numext::isnan(f64)); EXPECT_TRUE(Eigen::numext::isnan(f32)); EXPECT_TRUE(Eigen::numext::isnan(bf16)); EXPECT_TRUE(Eigen::numext::isnan(f16)); } else { EXPECT_EQ(f64.rep(), i) << i; EXPECT_EQ(f32.rep(), i) << i; EXPECT_EQ(bf16.rep(), i) << i; // E8M0 exponent range doesn't fit F16 type. if (!std::is_same_v) { EXPECT_EQ(f16.rep(), i) << i; } } } } TYPED_TEST(Float8Test, ConvertFromWithSaturation) { using Float8 = TypeParam; // Saturation above max value. Float8 upper = Float8::template ConvertFrom( static_cast(std::numeric_limits::max()) * 2); EXPECT_EQ(upper, std::numeric_limits::max()); if (std::numeric_limits::is_signed) { Float8 lower = Float8::template ConvertFrom( static_cast(std::numeric_limits::lowest()) * 2); EXPECT_EQ(lower, std::numeric_limits::lowest()); } // Special values remain with saturation. Float8 nan = Float8::template ConvertFrom( std::numeric_limits::quiet_NaN()); EXPECT_TRUE(Eigen::numext::isnan(nan)); Float8 inf = Float8::template ConvertFrom( std::numeric_limits::infinity()); // E4M3 doesn't have inf, so check inf -> NaN conversion. EXPECT_TRUE(std::numeric_limits::has_infinity ? Eigen::numext::isinf(inf) : Eigen::numext::isnan(inf)); Float8 ninf = Float8::template ConvertFrom( -std::numeric_limits::infinity()); EXPECT_TRUE(std::numeric_limits::has_infinity ? Eigen::numext::isinf(ninf) : Eigen::numext::isnan(ninf)); } TYPED_TEST(Float8Test, ConvertFromWithTruncation) { using Float8 = TypeParam; // Truncation and rounding of a number ever-so-slightly less than 2. 
float less_than_two = Eigen::numext::bit_cast(0x3FFFFFFF); Float8 truncated = Float8::template ConvertFrom( less_than_two); EXPECT_LT(static_cast(truncated), 2); Float8 rounded = Float8::template ConvertFrom( less_than_two); EXPECT_EQ(static_cast(rounded), 2); double kLarge = 0x1p+128; EXPECT_EQ( (Float8::template ConvertFrom( kLarge) .rep()), std::numeric_limits::infinity().rep()); EXPECT_EQ( (Float8::template ConvertFrom( kLarge) .rep()), std::numeric_limits::infinity().rep()); // Truncation and rounding of a subnormal. for (int i = 0x01; i < 0x04; ++i) { float less_than_subnorm = std::nexttoward(static_cast(Float8::FromRep(i)), 0); Float8 truncated_subnorm = Float8::template ConvertFrom( less_than_subnorm); EXPECT_EQ(truncated_subnorm.rep(), i - 1); Float8 rounded_subnorm = Float8::template ConvertFrom( less_than_subnorm); EXPECT_EQ(rounded_subnorm.rep(), i); } } TYPED_TEST(Float8Test, ConvertTo) { using Float8 = TypeParam; // Converting to higher precision types doesn't result in either // truncation or saturation, so let's just ensure they all provide the // same results. for (int i = 0x00; i <= 0xFF; ++i) { // Cast up to each other floating-point type, and verify they are the same. 
Float8 f8 = Float8::FromRep(i); float f32 = static_cast(f8); if (Eigen::numext::isnan(f8)) { EXPECT_TRUE( std::isnan(Float8::template ConvertTo(f8))); EXPECT_TRUE( std::isnan(Float8::template ConvertTo(f8))); EXPECT_TRUE( std::isnan(Float8::template ConvertTo(f8))); EXPECT_TRUE( std::isnan(Float8::template ConvertTo(f8))); } else { EXPECT_EQ(f32, (Float8::template ConvertTo(f8))); EXPECT_EQ(f32, (Float8::template ConvertTo(f8))); EXPECT_EQ(f32, (Float8::template ConvertTo(f8))); EXPECT_EQ(f32, (Float8::template ConvertTo(f8))); } } } template static SrcType DoubleRoundHelper() { // If we have a number of the form 1.0..010..010.., two rounds of RTNE can // cause the last-set bit to get rounded down due to RTNE which in turn will // cause the other bit to get rounded down due to RTNE. RTNE's tie breaking // semantics *should* not apply here as there is no tie but double-rounding // may confuse us. SrcType x{1.0}; x += std::ldexp(SrcType{1.0}, -std::numeric_limits::digits); x += std::ldexp(SrcType{1.0}, -std::numeric_limits::digits); auto rounded_x = static_cast(x); return static_cast(rounded_x); } // This test tries to capture mistakes in `float8_base::ConvertFrom` where it is // implemented by a series of conversions. e.g. converting a double to a float // to a float8 introduces double-rounding which makes the final rounding step // unfaithful. Craft a variety of numbers which try to detect if this happens. TYPED_TEST(Float8Test, DoubleRound) { using Float8 = TypeParam; // We expect that our number results in rounding up to the number after 1. // Incorrect rounding will result in 1. const double expected = 1.0 + static_cast(std::numeric_limits::epsilon()); EXPECT_EQ((DoubleRoundHelper()), expected); // Don't use long double on targets which don't support it. 
#if !defined(EIGEN_USE_GPU) && !defined(EIGEN_GPU_COMPILE_PHASE) EXPECT_EQ((DoubleRoundHelper()), expected); EXPECT_EQ((DoubleRoundHelper()), expected); Float8 max = std::numeric_limits::max(); Float8 saturated = Float8::template ConvertFrom( std::numeric_limits::max()); EXPECT_EQ(max, saturated); #endif } TEST(Float8Test, Float8E5m2_To_Float8E4m3fn) { // Saturation. float8_e5m2 max = std::numeric_limits::max(); float8_e4m3fn saturated = float8_e4m3fn::ConvertFrom(max); EXPECT_EQ(saturated, std::numeric_limits::max()); saturated = float8_e5m2::ConvertTo(max); EXPECT_EQ(saturated, std::numeric_limits::max()); // Truncation - only occurs for e4m3 subnormals. float8_e5m2 less_than_subnorm = float8_e5m2::FromRep(0x1F); // 2^-7 - 2^-10. float8_e4m3fn rounded_subnorm = float8_e4m3fn::ConvertFrom( less_than_subnorm); EXPECT_EQ(rounded_subnorm.rep(), 0x04); float8_e4m3fn truncated_subnorm = float8_e4m3fn::ConvertFrom( less_than_subnorm); EXPECT_EQ(truncated_subnorm.rep(), 0x03); } TEST(Float8Test, Half_To_Float8E4m3fn) { Eigen::half big_half(0x1.dfcp+8f); float8_e4m3fn big_e4m3fn = float8_e4m3fn::ConvertFrom( big_half); EXPECT_EQ(big_e4m3fn.rep(), std::numeric_limits::max().rep()); } TEST(Float8Test, Float8E5m2_To_Float8E4m3b11fnuz) { // Saturation. float8_e5m2 max = std::numeric_limits::max(); float8_e4m3b11fnuz saturated = float8_e4m3b11fnuz::ConvertFrom(max); EXPECT_EQ(saturated, std::numeric_limits::max()); saturated = float8_e5m2::ConvertTo(max); EXPECT_EQ(saturated, std::numeric_limits::max()); // Truncation - only occurs for e4m3 subnormals. float8_e5m2 less_than_subnorm = float8_e5m2::FromRep(0x0F); // 2^-11 - 2^-14. float8_e4m3b11fnuz rounded_subnorm = float8_e4m3b11fnuz::ConvertFrom( less_than_subnorm); EXPECT_EQ(rounded_subnorm.rep(), 0x04); float8_e4m3b11fnuz truncated_subnorm = float8_e4m3b11fnuz::ConvertFrom( less_than_subnorm); EXPECT_EQ(truncated_subnorm.rep(), 0x03); // Saturation. 
for (uint8_t i = 0; i < std::numeric_limits::infinity().rep(); ++i) { float8_e5m2 big_e5m2 = Eigen::numext::bit_cast(i); EXPECT_TRUE(Eigen::numext::isfinite(big_e5m2)) << uint16_t{i}; float big_float = static_cast(big_e5m2); auto big_e4m3 = float8_e4m3b11fnuz::ConvertFrom(big_float); if (i > 0x4f) { EXPECT_EQ(big_e4m3.rep(), std::numeric_limits::max().rep()) << uint16_t{i}; } EXPECT_EQ((float8_e4m3b11fnuz::ConvertFrom(big_e5m2) .rep()), big_e4m3.rep()) << i; EXPECT_EQ((float8_e4m3b11fnuz::ConvertFrom(-big_e5m2) .rep()), (-big_e4m3).rep()) << i; } } TEST(Float8Test, Float8E4m3b11fnuz_To_Float8E4m3fn) { // Saturation. float8_e4m3b11fnuz max = std::numeric_limits::max(); float8_e4m3fn saturated = float8_e4m3fn::ConvertFrom(max); EXPECT_EQ(static_cast(saturated), static_cast(std::numeric_limits::max())); saturated = float8_e4m3b11fnuz::ConvertTo(max); EXPECT_EQ(static_cast(saturated), static_cast(std::numeric_limits::max())); // Truncation - only occurs for e4m3 subnormals. float8_e4m3b11fnuz less_than_subnorm = float8_e4m3b11fnuz::FromRep(0b0011'110); // 2^-7 - 2^-10. float8_e4m3fn rounded_subnorm = float8_e4m3fn::ConvertFrom( less_than_subnorm); EXPECT_EQ(rounded_subnorm.rep(), 0x04); float8_e4m3fn truncated_subnorm = float8_e4m3fn::ConvertFrom( less_than_subnorm); EXPECT_EQ(truncated_subnorm.rep(), 0x03); // Saturation. for (uint8_t i = 0; i < std::numeric_limits::infinity().rep(); ++i) { float8_e4m3b11fnuz big_e4m3b11fnuz = Eigen::numext::bit_cast(i); EXPECT_TRUE(Eigen::numext::isfinite(big_e4m3b11fnuz)) << uint16_t{i}; float big_float = static_cast(big_e4m3b11fnuz); auto big_e4m3 = float8_e4m3fn::ConvertFrom( big_float); EXPECT_EQ( (float8_e4m3fn::ConvertFrom( big_e4m3b11fnuz) .rep()), big_e4m3.rep()) << i; EXPECT_EQ( (float8_e4m3fn::ConvertFrom( -big_e4m3b11fnuz) .rep()), (big_float > 0.0f ? -big_e4m3 : big_e4m3).rep()) << i; } } TEST(Float8Test, Float8E3m4_To_Float8E5m2) { // Truncation and rounding of a number ever-so-slightly less than 2. 
float8_e3m4 less_than_two = float8_e3m4::FromRep(0x3F); float8_e5m2 truncated = float8_e5m2::template ConvertFrom(less_than_two); EXPECT_LT(static_cast(truncated), 2); float8_e5m2 rounded = float8_e5m2::template ConvertFrom(less_than_two); EXPECT_EQ(static_cast(rounded), 2); } TEST(Float8Test, Float8E4m3_To_Float8E5m2) { // Truncation and rounding of a number ever-so-slightly less than 2. float8_e4m3 less_than_two = float8_e4m3::FromRep(0x3F); float8_e5m2 truncated = float8_e5m2::template ConvertFrom(less_than_two); EXPECT_LT(static_cast(truncated), 2); float8_e5m2 rounded = float8_e5m2::template ConvertFrom(less_than_two); EXPECT_EQ(static_cast(rounded), 2); } TEST(Float8Test, Float8E4m3fn_To_Float8E5m2) { // Truncation and rounding of a number ever-so-slightly less than 2. float8_e4m3fn less_than_two = float8_e4m3fn::FromRep(0x3F); float8_e5m2 truncated = float8_e5m2::template ConvertFrom(less_than_two); EXPECT_LT(static_cast(truncated), 2); float8_e5m2 rounded = float8_e5m2::template ConvertFrom(less_than_two); EXPECT_EQ(static_cast(rounded), 2); } TEST(Float8Test, Half_To_Float8E3m4) { // Special values, NaN. Eigen::half inf = Eigen::numext::bit_cast(static_cast(0x7C00)); EXPECT_EQ(static_cast(inf).rep(), 0x70); Eigen::half ninf = Eigen::numext::bit_cast(static_cast(0xFC00)); EXPECT_EQ(static_cast(ninf).rep(), 0xF0); Eigen::half nan = Eigen::numext::bit_cast(static_cast(0x7C01)); EXPECT_EQ(static_cast(nan).rep(), 0x78); Eigen::half nnan = Eigen::numext::bit_cast(static_cast(0xFC01)); EXPECT_EQ(static_cast(nnan).rep(), 0xF8); // Rounding vs truncation. Eigen::half less_than_two = Eigen::numext::bit_cast(static_cast(0x3FFF)); EXPECT_EQ((float8_e3m4::ConvertFrom(less_than_two) .rep()), 0x40); EXPECT_EQ((float8_e3m4::ConvertFrom(less_than_two) .rep()), 0x3F); EXPECT_EQ((float8_e3m4::ConvertFrom(-less_than_two) .rep()), 0xC0); EXPECT_EQ((float8_e3m4::ConvertFrom(-less_than_two) .rep()), 0xBF); // Saturation. 
// f8e3m4=0.110.1111 0x1.Fp+3 f16=0.10010.1111000000 uint16=0x4BC0 // f8e3m4=0.111.0000 0x1.0p+4 f16=0.10011.0000000000 uint16=0x4C00 for (uint16_t i = 0x4BC0; i < 0x4C00; ++i) { Eigen::half big_half = Eigen::numext::bit_cast(i); float big_float = static_cast(big_half); EXPECT_EQ( (float8_e3m4::ConvertFrom( big_half) .rep()), (float8_e3m4::ConvertFrom( big_float) .rep())) << i; EXPECT_EQ( (float8_e3m4::ConvertFrom( -big_half) .rep()), (float8_e3m4::ConvertFrom( -big_float) .rep())) << i; } } TEST(Float8Test, Half_To_Float8E4m3) { // Special values, NaN. Eigen::half inf = Eigen::numext::bit_cast(static_cast(0x7C00)); EXPECT_EQ(static_cast(inf).rep(), 0x78); Eigen::half ninf = Eigen::numext::bit_cast(static_cast(0xFC00)); EXPECT_EQ(static_cast(ninf).rep(), 0xF8); Eigen::half nan = Eigen::numext::bit_cast(static_cast(0x7C01)); EXPECT_EQ(static_cast(nan).rep(), 0x7C); Eigen::half nnan = Eigen::numext::bit_cast(static_cast(0xFC01)); EXPECT_EQ(static_cast(nnan).rep(), 0xFC); // Rounding vs truncation. Eigen::half less_than_two = Eigen::numext::bit_cast(static_cast(0x3FFF)); EXPECT_EQ((float8_e4m3::ConvertFrom(less_than_two) .rep()), 0x40); EXPECT_EQ((float8_e4m3::ConvertFrom(less_than_two) .rep()), 0x3F); EXPECT_EQ((float8_e4m3::ConvertFrom(-less_than_two) .rep()), 0xC0); EXPECT_EQ((float8_e4m3::ConvertFrom(-less_than_two) .rep()), 0xBF); // Saturation. // f8e4m3=0.1110.111 0x1.Ep+7 f16=0.10110.1110000000 uint16=0x5B80 // f8e4m3=0.1111.000 0x1.0p+8 f16=0.10111.0000000000 uint16=0x5C00 for (uint16_t i = 0x5B80; i < 0x5C00; ++i) { Eigen::half big_half = Eigen::numext::bit_cast(i); float big_float = static_cast(big_half); EXPECT_EQ( (float8_e4m3::ConvertFrom( big_half) .rep()), (float8_e4m3::ConvertFrom( big_float) .rep())) << i; EXPECT_EQ( (float8_e4m3::ConvertFrom( -big_half) .rep()), (float8_e4m3::ConvertFrom( -big_float) .rep())) << i; } } TEST(Float8Test, Half_To_Float8E5m2) { // Special values, NaN. 
Eigen::half inf = Eigen::numext::bit_cast(static_cast(0x7C00)); EXPECT_EQ(static_cast(inf).rep(), 0x7C); Eigen::half ninf = Eigen::numext::bit_cast(static_cast(0xFC00)); EXPECT_EQ(static_cast(ninf).rep(), 0xFC); Eigen::half nan = Eigen::numext::bit_cast(static_cast(0x7C01)); EXPECT_EQ(static_cast(nan).rep(), 0x7E); Eigen::half nnan = Eigen::numext::bit_cast(static_cast(0xFC01)); EXPECT_EQ(static_cast(nnan).rep(), 0xFE); // Rounding vs truncation. Eigen::half less_than_two = Eigen::numext::bit_cast(static_cast(0x3FFF)); EXPECT_EQ((float8_e5m2::ConvertFrom(less_than_two) .rep()), 0x40); EXPECT_EQ((float8_e5m2::ConvertFrom(less_than_two) .rep()), 0x3F); EXPECT_EQ((float8_e5m2::ConvertFrom(-less_than_two) .rep()), 0xC0); EXPECT_EQ((float8_e5m2::ConvertFrom(-less_than_two) .rep()), 0xBF); // Saturation. for (uint16_t i = static_cast(Eigen::numext::bit_cast( std::numeric_limits::max())) << 8; i < Eigen::numext::bit_cast( std::numeric_limits::infinity()); ++i) { Eigen::half big_half = Eigen::numext::bit_cast(i); float big_float = static_cast(big_half); EXPECT_EQ( (float8_e5m2::ConvertFrom( big_half) .rep()), (float8_e5m2::ConvertFrom( big_float) .rep())) << i; EXPECT_EQ( (float8_e5m2::ConvertFrom( -big_half) .rep()), (float8_e5m2::ConvertFrom( -big_float) .rep())) << i; } } using ::testing::Eq; using ::testing::IsTrue; MATCHER_P(EqOrIsNan, other, "") { if (Eigen::numext::isnan(other)) { return ExplainMatchResult(IsTrue(), Eigen::numext::isnan(arg), result_listener); } return ExplainMatchResult(Eq(other), arg, result_listener); } TYPED_TEST(Float8Test, CallTheOperator) { using Float8 = TypeParam; for (int i = 0x00; i <= 0xFF; ++i) { Float8 a = Float8::FromRep(i); for (int j = 0x00; j <= 0xFF; ++j) { Float8 b = Float8::FromRep(j); EXPECT_THAT(a + b, EqOrIsNan(Float8{float{a} + float{b}})); EXPECT_THAT(a - b, EqOrIsNan(Float8{float{a} - float{b}})); EXPECT_THAT(a * b, EqOrIsNan(Float8{float{a} * float{b}})); EXPECT_THAT(a / b, EqOrIsNan(Float8{float{a} / float{b}})); Float8 
c; EXPECT_THAT((c = a, c += b), EqOrIsNan(Float8{float{a} + float{b}})); EXPECT_THAT((c = a, c -= b), EqOrIsNan(Float8{float{a} - float{b}})); EXPECT_THAT((c = a, c *= b), EqOrIsNan(Float8{float{a} * float{b}})); EXPECT_THAT((c = a, c /= b), EqOrIsNan(Float8{float{a} / float{b}})); EXPECT_EQ(a == b, float{a} == float{b}) << float{a} << " vs " << float{b}; EXPECT_EQ(a != b, float{a} != float{b}); EXPECT_EQ(a < b, float{a} < float{b}); EXPECT_EQ(a <= b, float{a} <= float{b}); EXPECT_EQ(a > b, float{a} > float{b}); EXPECT_EQ(a >= b, float{a} >= float{b}); } } } TYPED_TEST(Float8Test, CallTheConstOperator) { using Float8 = TypeParam; for (int i = 0x00; i <= 0xFF; ++i) { const Float8 a = Float8::FromRep(i); for (int j = 0x00; j <= 0xFF; ++j) { const Float8 b = Float8::FromRep(j); EXPECT_THAT(a + b, EqOrIsNan(Float8{float{a} + float{b}})); EXPECT_THAT(a - b, EqOrIsNan(Float8{float{a} - float{b}})); EXPECT_THAT(a * b, EqOrIsNan(Float8{float{a} * float{b}})); EXPECT_THAT(a / b, EqOrIsNan(Float8{float{a} / float{b}})); Float8 c; EXPECT_THAT((c = a, c += b), EqOrIsNan(Float8{float{a} + float{b}})); EXPECT_THAT((c = a, c -= b), EqOrIsNan(Float8{float{a} - float{b}})); EXPECT_THAT((c = a, c *= b), EqOrIsNan(Float8{float{a} * float{b}})); EXPECT_THAT((c = a, c /= b), EqOrIsNan(Float8{float{a} / float{b}})); EXPECT_EQ(a == b, float{a} == float{b}) << float{a} << " vs " << float{b}; EXPECT_EQ(a != b, float{a} != float{b}); EXPECT_EQ(a < b, float{a} < float{b}) << float{a} << " vs " << float{b}; EXPECT_EQ(a <= b, float{a} <= float{b}); EXPECT_EQ(a > b, float{a} > float{b}) << float{a} << " vs " << float{b}; EXPECT_EQ(a >= b, float{a} >= float{b}); } } } TEST(Float8E3m4Test, SmallCastToDenormal) { // Special edge-case where rounding to a normalized value would // normally round down, but rounding to a subnormal rounds up. 
float x = 0x0.8Ap-2; // btw denormals float8_e3m4 y = static_cast(x); float z = static_cast(y); EXPECT_EQ(z, 0x0.9p-2); // rounded up to the next denormal } TEST(Float8E4m3Test, SmallCastToDenormal) { // Special edge-case where rounding to a normalized value would // normally round down, but rounding to a subnormal rounds up. float x = 0x0.94p-6; // btw denormals float8_e4m3 y = static_cast(x); float z = static_cast(y); EXPECT_EQ(z, 0x0.Ap-6); // rounded up to the next denormal } TEST(Float8E5m2Test, SmallCastToDenormal) { // Special edge-case where rounding to a normalized value would // normally round down, but rounding to a subnormal rounds up. float x = 0x0.A8p-14; // btw denormals float8_e5m2 y = static_cast(x); float z = static_cast(y); EXPECT_EQ(z, 0x0.Cp-14); // rounded up to the next denormal } // Helper utility for prettier test names. struct Float8CastTestParamNames { template static std::string GetName(int idx) { using first_type = typename TypeParam::first_type; using second_type = typename TypeParam::second_type; return ::testing::internal::GetTypeName() + "_" + ::testing::internal::GetTypeName(); } }; #if !defined(EIGEN_USE_GPU) && !defined(EIGEN_GPU_COMPILE_PHASE) // long double doesn't work on GPU - it is treated as a regular 8-byte // double, which differs in size from the 16-byte long double on intel CPU. 
#define GEN_LONG_DOUBLE_PAIR(Type) std::pair, #else #define GEN_LONG_DOUBLE_PAIR(Type) #endif #define GEN_DEST_TYPES(Type) \ GEN_LONG_DOUBLE_PAIR(Type) \ std::pair, std::pair, \ std::pair, std::pair, \ std::pair, std::pair, \ std::pair, std::pair, \ std::pair, std::pair, \ std::pair, std::pair, \ std::pair, std::pair, \ std::pair #define GEN_TYPE_PAIRS() \ GEN_DEST_TYPES(float8_e3m4), GEN_DEST_TYPES(float8_e4m3), \ GEN_DEST_TYPES(float8_e4m3fn), GEN_DEST_TYPES(float8_e4m3b11fnuz), \ GEN_DEST_TYPES(float8_e5m2), GEN_DEST_TYPES(float8_e4m3fnuz), \ GEN_DEST_TYPES(float8_e5m2fnuz), GEN_DEST_TYPES(float8_e8m0fnu) using Float8CastTypePairs = ::testing::Types; template class Float8CastTest : public ::testing::Test {}; TYPED_TEST_SUITE(Float8CastTest, Float8CastTypePairs, Float8CastTestParamNames); TYPED_TEST(Float8CastTest, CastThroughFloat) { using Float8 = typename TypeParam::first_type; using DestType = typename TypeParam::second_type; for (int i = 0x00; i <= 0xFF; ++i) { Float8 f8 = Float8::FromRep(i); if constexpr (std::numeric_limits::is_integer && !std::is_same_v) { if (!Eigen::numext::isfinite(f8) || static_cast(std::numeric_limits::max()) <= f8) { continue; } } DestType dest; // Eigen floats define a template constructor that turns the static_cast // into a cast from f8 to float to DestType, which is exactly what we have // in `expected`, so we special case float types here. 
if constexpr (!std::is_integral_v && !std::is_same_v) { dest = Float8::template ConvertTo(f8); } else { dest = static_cast(f8); } DestType expected = static_cast(static_cast(f8)); EXPECT_THAT(dest, EqOrIsNan(expected)); } } TYPED_TEST(Float8CastTest, DeviceCast) { using Float8 = typename TypeParam::first_type; using DestType = typename TypeParam::second_type; #if defined(EIGEN_USE_GPU) Eigen::GpuStreamDevice stream; Eigen::GpuDevice device(&stream); #elif defined(EIGEN_USE_THREADS) constexpr int kThreads = 4; Eigen::ThreadPool tp(kThreads); Eigen::ThreadPoolDevice device(&tp, kThreads); #else Eigen::DefaultDevice device; #endif const int kNumElems = 256; // Allocate device buffers and create device tensors. Float8* src_device_buffer = (Float8*)device.allocate(kNumElems * sizeof(Float8)); DestType* dst_device_buffer = (DestType*)device.allocate(kNumElems * sizeof(DestType)); Eigen::TensorMap, Eigen::Aligned> src_device( src_device_buffer, kNumElems); Eigen::TensorMap, Eigen::Aligned> dst_device( dst_device_buffer, kNumElems); // Allocate host buffers and initially src memory. Eigen::Tensor src_cpu(kNumElems); Eigen::Tensor dst_cpu(kNumElems); using limits = std::numeric_limits; for (int i = 0; i < kNumElems; ++i) { src_cpu(i) = Eigen::numext::bit_cast(static_cast(i)); // If src is inf or nan or has type overflow but DestType doesn't support // such values (e.g. integer types), replace the input with a zero. if ((!limits::has_quiet_NaN && Eigen::numext::isnan(src_cpu(i))) || (!limits::has_infinity && Eigen::numext::isinf(src_cpu(i))) || (limits::is_integer && !std::is_same_v && static_cast(limits::max()) <= src_cpu(i))) { src_cpu(i) = src_cpu(0); } } // Transfer data to device, perform a cast to DestType, then transfer result // back to host. 
device.memcpyHostToDevice(src_device_buffer, src_cpu.data(), kNumElems * sizeof(Float8)); dst_device.device(device) = src_device.template cast(); device.memcpyDeviceToHost(dst_cpu.data(), dst_device_buffer, kNumElems * sizeof(DestType)); device.synchronize(); for (int i = 0; i < kNumElems; ++i) { DestType expected = static_cast(src_cpu(i)); EXPECT_THAT(dst_cpu(i), EqOrIsNan(expected)); } // Cast back from DestType to Float8. // First clear out the device src buffer, since that will be the destination. src_cpu.setZero(); device.memcpyHostToDevice(src_device_buffer, src_cpu.data(), kNumElems * sizeof(Float8)); src_device.device(device) = dst_device.template cast(); device.memcpyDeviceToHost(src_cpu.data(), src_device_buffer, kNumElems * sizeof(Float8)); device.synchronize(); for (int i = 0; i < kNumElems; ++i) { Float8 expected = static_cast(dst_cpu(i)); EXPECT_THAT(src_cpu(i), EqOrIsNan(expected)); } // Clean up. device.deallocate(src_device_buffer); device.deallocate(dst_device_buffer); device.synchronize(); } } // namespace } // namespace ml_dtypes jax-ml-ml_dtypes-882eb0f/ml_dtypes/tests/iinfo_test.py000066400000000000000000000062321510671665600232170ustar00rootroot00000000000000# Copyright 2022 The ml_dtypes Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from absl.testing import absltest from absl.testing import parameterized import ml_dtypes from multi_thread_utils import multi_threaded import numpy as np @multi_threaded(num_workers=3) class IinfoTest(parameterized.TestCase): def testIinfoInt2(self): info = ml_dtypes.iinfo(ml_dtypes.int2) self.assertEqual(info.dtype, ml_dtypes.iinfo("int2").dtype) self.assertEqual(info.dtype, ml_dtypes.iinfo(np.dtype("int2")).dtype) self.assertEqual(info.min, -2) self.assertEqual(info.max, 1) self.assertEqual(info.dtype, np.dtype(ml_dtypes.int2)) self.assertEqual(info.bits, 2) self.assertEqual(info.kind, "i") self.assertEqual(str(info), "iinfo(min=-2, max=1, dtype=int2)") def testIInfoUint2(self): info = ml_dtypes.iinfo(ml_dtypes.uint2) self.assertEqual(info.dtype, ml_dtypes.iinfo("uint2").dtype) self.assertEqual(info.dtype, ml_dtypes.iinfo(np.dtype("uint2")).dtype) self.assertEqual(info.min, 0) self.assertEqual(info.max, 3) self.assertEqual(info.dtype, np.dtype(ml_dtypes.uint2)) self.assertEqual(info.bits, 2) self.assertEqual(info.kind, "u") self.assertEqual(str(info), "iinfo(min=0, max=3, dtype=uint2)") def testIinfoInt4(self): info = ml_dtypes.iinfo(ml_dtypes.int4) self.assertEqual(info.dtype, ml_dtypes.iinfo("int4").dtype) self.assertEqual(info.dtype, ml_dtypes.iinfo(np.dtype("int4")).dtype) self.assertEqual(info.min, -8) self.assertEqual(info.max, 7) self.assertEqual(info.dtype, np.dtype(ml_dtypes.int4)) self.assertEqual(info.bits, 4) self.assertEqual(info.kind, "i") self.assertEqual(str(info), "iinfo(min=-8, max=7, dtype=int4)") def testIInfoUint4(self): info = ml_dtypes.iinfo(ml_dtypes.uint4) self.assertEqual(info.dtype, ml_dtypes.iinfo("uint4").dtype) self.assertEqual(info.dtype, ml_dtypes.iinfo(np.dtype("uint4")).dtype) self.assertEqual(info.min, 0) self.assertEqual(info.max, 15) self.assertEqual(info.dtype, np.dtype(ml_dtypes.uint4)) self.assertEqual(info.bits, 4) self.assertEqual(info.kind, "u") self.assertEqual(str(info), "iinfo(min=0, max=15, dtype=uint4)") def 
testIinfoInt8(self): # Checks iinfo succeeds for a built-in NumPy type. info = ml_dtypes.iinfo(np.int8) self.assertEqual(info.min, -128) self.assertEqual(info.max, 127) def testIinfoNonInteger(self): with self.assertRaises(ValueError): ml_dtypes.iinfo(np.float32) with self.assertRaises(ValueError): ml_dtypes.iinfo(np.complex128) with self.assertRaises(ValueError): ml_dtypes.iinfo(bool) if __name__ == "__main__": absltest.main() jax-ml-ml_dtypes-882eb0f/ml_dtypes/tests/intn_test.cc000066400000000000000000000405311510671665600230200ustar00rootroot00000000000000/* Copyright 2023 The ml_dtypes Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "ml_dtypes/include/intn.h" #include #include #include #include #include #include #include #include #include "Eigen/Core" #include "unsupported/Eigen/CXX11/Tensor" namespace ml_dtypes { namespace { template struct is_intN : std::false_type {}; template struct is_intN> : std::true_type {}; template inline constexpr bool is_intN_v = is_intN::value; template class IntNTest : public ::testing::Test {}; // Helper utility for prettier test names. 
struct IntNTestParamNames { template static std::string GetName(int idx) { if constexpr (is_intN_v) { std::string name; name.reserve(5); if constexpr (std::is_unsigned_v) { name.append("u"); } name.append("int"); name.append(std::to_string(TypeParam::bits)); return name; } return std::to_string(idx); } }; using IntNTypes = ::testing::Types; TYPED_TEST_SUITE(IntNTest, IntNTypes, IntNTestParamNames); TEST(IntNTest, NumericLimits) { EXPECT_EQ(std::numeric_limits::is_signed, true); EXPECT_EQ(std::numeric_limits::is_modulo, false); EXPECT_EQ(static_cast(std::numeric_limits::min()), -8); EXPECT_EQ(static_cast(std::numeric_limits::max()), 7); EXPECT_EQ(static_cast(std::numeric_limits::lowest()), -8); EXPECT_EQ(std::numeric_limits::digits, 3); EXPECT_EQ(std::numeric_limits::digits10, 0); EXPECT_EQ(std::numeric_limits::is_signed, true); EXPECT_EQ(std::numeric_limits::is_modulo, false); EXPECT_EQ(static_cast(std::numeric_limits::min()), -1); EXPECT_EQ(static_cast(std::numeric_limits::max()), 0); EXPECT_EQ(static_cast(std::numeric_limits::lowest()), -1); EXPECT_EQ(std::numeric_limits::digits, 0); EXPECT_EQ(std::numeric_limits::digits10, 0); } TEST(UIntNTest, NumericLimits) { EXPECT_EQ(std::numeric_limits::is_signed, false); EXPECT_EQ(std::numeric_limits::is_modulo, true); EXPECT_EQ(static_cast(std::numeric_limits::min()), 0); EXPECT_EQ(static_cast(std::numeric_limits::max()), 15); EXPECT_EQ(static_cast(std::numeric_limits::lowest()), 0); EXPECT_EQ(std::numeric_limits::digits, 4); EXPECT_EQ(std::numeric_limits::digits10, 1); EXPECT_EQ(std::numeric_limits::is_signed, false); EXPECT_EQ(std::numeric_limits::is_modulo, true); EXPECT_EQ(static_cast(std::numeric_limits::min()), 0); EXPECT_EQ(static_cast(std::numeric_limits::max()), 1); EXPECT_EQ(static_cast(std::numeric_limits::lowest()), 0); EXPECT_EQ(std::numeric_limits::digits, 1); EXPECT_EQ(std::numeric_limits::digits10, 0); } TYPED_TEST(IntNTest, NumericLimitsBase) { using IntN = TypeParam; 
EXPECT_EQ(std::numeric_limits::is_specialized, true); EXPECT_EQ(std::numeric_limits::is_integer, true); EXPECT_EQ(std::numeric_limits::is_exact, true); EXPECT_EQ(std::numeric_limits::has_infinity, false); EXPECT_EQ(std::numeric_limits::has_quiet_NaN, false); EXPECT_EQ(std::numeric_limits::has_signaling_NaN, false); #if !defined(__cplusplus) || __cplusplus < 202302L EXPECT_EQ(std::numeric_limits::has_denorm, std::denorm_absent); EXPECT_EQ(std::numeric_limits::has_denorm_loss, false); #endif EXPECT_EQ(std::numeric_limits::round_style, std::round_toward_zero); EXPECT_EQ(std::numeric_limits::is_iec559, false); EXPECT_EQ(std::numeric_limits::is_bounded, true); EXPECT_EQ(std::numeric_limits::radix, 2); EXPECT_EQ(std::numeric_limits::min_exponent, 0); EXPECT_EQ(std::numeric_limits::min_exponent10, 0); EXPECT_EQ(std::numeric_limits::max_exponent, 0); EXPECT_EQ(std::numeric_limits::max_exponent10, 0); EXPECT_EQ(std::numeric_limits::traps, true); EXPECT_EQ(std::numeric_limits::tinyness_before, false); EXPECT_EQ(static_cast(std::numeric_limits::epsilon()), 0); EXPECT_EQ(static_cast(std::numeric_limits::round_error()), 0); EXPECT_EQ(static_cast(std::numeric_limits::infinity()), 0); EXPECT_EQ(static_cast(std::numeric_limits::quiet_NaN()), 0); EXPECT_EQ(static_cast(std::numeric_limits::signaling_NaN()), 0); EXPECT_EQ(static_cast(std::numeric_limits::denorm_min()), 0); } TYPED_TEST(IntNTest, TypeTraits) { using IntN = TypeParam; EXPECT_TRUE(std::is_trivially_copyable_v); EXPECT_TRUE(std::is_default_constructible_v); EXPECT_TRUE(std::is_nothrow_constructible_v); } TYPED_TEST(IntNTest, CreateAndAssign) { using IntN = TypeParam; // Constructors. EXPECT_EQ(IntN(), IntN(0)); IntN a(1); EXPECT_EQ(a, IntN(1)); IntN b(std::move(a)); EXPECT_EQ(b, IntN(1)); // Assignments. 
EXPECT_EQ(a = IntN(2), IntN(2)); EXPECT_EQ(b = a, IntN(2)); EXPECT_EQ((a = IntN(3), b = std::move(a)), IntN(3)); } // To ensure an expression is evaluated in a constexpr context, // we use the trick of inserting the expression in a template // parameter. template struct ConstexprEvaluator { static constexpr bool val = true; }; // To avoid warnings about unused left-side of comma expressions, // we additionally pass the expression through a constexpr function. template constexpr void ConstexprEvaluatorFunc(T&&) {} #define TEST_CONSTEXPR(expr) \ do { \ EXPECT_TRUE((ConstexprEvaluator<(ConstexprEvaluatorFunc(expr), 1)>::val)); \ } while (false) TYPED_TEST(IntNTest, Constexpr) { TEST_CONSTEXPR(int4(0)); TEST_CONSTEXPR(static_cast(int4(0))); TEST_CONSTEXPR(-int4(1)); TEST_CONSTEXPR(int4(0) + int4(1)); TEST_CONSTEXPR(int4(1) - int4(0)); TEST_CONSTEXPR(int4(0) * int4(1)); TEST_CONSTEXPR(int4(0) / int4(1)); TEST_CONSTEXPR(int4(0) % int4(1)); TEST_CONSTEXPR(int4(1) & int4(0xF)); TEST_CONSTEXPR(int4(1) | int4(0xF)); TEST_CONSTEXPR(int4(1) ^ int4(0xF)); TEST_CONSTEXPR(~int4(1)); TEST_CONSTEXPR(int4(1) >> 1); TEST_CONSTEXPR(int4(1) << 1); TEST_CONSTEXPR(int4(1) == int4(1)); TEST_CONSTEXPR(int4(1) != int4(1)); TEST_CONSTEXPR(int4(1) < int4(1)); TEST_CONSTEXPR(int4(1) > int4(1)); TEST_CONSTEXPR(int4(1) <= int4(1)); TEST_CONSTEXPR(int4(1) >= int4(1)); TEST_CONSTEXPR(++int4(1)); TEST_CONSTEXPR(int4(1)++); TEST_CONSTEXPR(--int4(1)); TEST_CONSTEXPR(int4(1)--); TEST_CONSTEXPR(int4(1) += int4(2)); TEST_CONSTEXPR(int4(1) -= int4(2)); TEST_CONSTEXPR(int4(1) *= int4(2)); TEST_CONSTEXPR(int4(1) /= int4(2)); TEST_CONSTEXPR(int4(1) %= int4(2)); TEST_CONSTEXPR(int4(1) &= int4(2)); TEST_CONSTEXPR(int4(1) |= int4(2)); TEST_CONSTEXPR(int4(1) ^= int4(2)); TEST_CONSTEXPR(int4(1) >>= 1); TEST_CONSTEXPR(int4(1) <<= 1); } template IntN CreateIntNWithRandomHighBits(int val) { return Eigen::numext::bit_cast(static_cast( val | (Eigen::internal::random() << IntN::bits))); } TYPED_TEST(IntNTest, Casts) { 
using IntN = TypeParam; // Explicit integer types. if constexpr (IntN::bits == 4) { EXPECT_EQ(static_cast(IntN(4)), 4); EXPECT_EQ(static_cast(IntN(5)), 5); EXPECT_EQ(static_cast(IntN(6)), 6); EXPECT_EQ(static_cast(IntN(7)), 7); EXPECT_EQ(static_cast(IntN(1)), 1); } // Implicit conversion to optional. std::optional c = IntN(0); EXPECT_EQ(c, 0); // Loop through all valid values. for (int i = static_cast(std::numeric_limits::min()); i <= static_cast(std::numeric_limits::max()); ++i) { // Round-trip. EXPECT_EQ(static_cast(CreateIntNWithRandomHighBits(i)), i); // Float truncation. for (int j = 1; j < 10; ++j) { float offset = -1.f + j * 1.f / 5; float f = i + offset; EXPECT_EQ(IntN(f), IntN(static_cast(f))); } } } TYPED_TEST(IntNTest, Operators) { using IntN = TypeParam; for (int i = static_cast(std::numeric_limits::min()); i <= static_cast(std::numeric_limits::max()); ++i) { IntN x = CreateIntNWithRandomHighBits(i); EXPECT_EQ(-x, IntN(-i)); EXPECT_EQ(~x, IntN(~i)); IntN a; EXPECT_EQ((a = x, ++a), IntN(i + 1)); EXPECT_EQ(a, IntN(i + 1)); EXPECT_EQ((a = x, a++), IntN(i)); EXPECT_EQ(a, IntN(i + 1)); EXPECT_EQ((a = x, --a), IntN(i - 1)); EXPECT_EQ(a, IntN(i - 1)); EXPECT_EQ((a = x, a--), IntN(i)); EXPECT_EQ(a, IntN(i - 1)); for (int j = static_cast(std::numeric_limits::min()); j <= static_cast(std::numeric_limits::max()); ++j) { IntN y = CreateIntNWithRandomHighBits(j); EXPECT_EQ(x + y, IntN(i + j)); EXPECT_EQ(x - y, IntN(i - j)); EXPECT_EQ(x * y, IntN(i * j)); if (j != 0) { EXPECT_EQ(x / y, IntN(i / j)); EXPECT_EQ(x % y, IntN(i % j)); } EXPECT_EQ(x & y, IntN(i & j)); EXPECT_EQ(x | y, IntN(i | j)); EXPECT_EQ(x ^ y, IntN(i ^ j)); EXPECT_EQ(x == y, i == j); EXPECT_EQ(x != y, i != j); EXPECT_EQ(x < y, i < j); EXPECT_EQ(x > y, i > j); EXPECT_EQ(x <= y, i <= j); EXPECT_EQ(x >= y, i >= j); EXPECT_EQ(x == static_cast(j), i == j); EXPECT_EQ(x != static_cast(j), i != j); EXPECT_EQ(x < static_cast(j), i < j); EXPECT_EQ(x > static_cast(j), i > j); EXPECT_EQ(x <= static_cast(j), i <= 
j); EXPECT_EQ(x >= static_cast(j), i >= j); EXPECT_EQ(static_cast(j) == x, j == i); EXPECT_EQ(static_cast(j) != x, j != i); EXPECT_EQ(static_cast(j) < x, j < i); EXPECT_EQ(static_cast(j) > x, j > i); EXPECT_EQ(static_cast(j) <= x, j <= i); EXPECT_EQ(static_cast(j) >= x, j >= i); EXPECT_EQ((a = x, a += y), IntN(i + j)); EXPECT_EQ((a = x, a -= y), IntN(i - j)); EXPECT_EQ((a = x, a *= y), IntN(i * j)); if (j != 0) { EXPECT_EQ((a = x, a /= y), IntN(i / j)); EXPECT_EQ((a = x, a %= y), IntN(i % j)); } EXPECT_EQ((a = x, a &= y), IntN(i & j)); EXPECT_EQ((a = x, a |= y), IntN(i | j)); EXPECT_EQ((a = x, a ^= y), IntN(i ^ j)); } for (int amount = 0; amount < IntN::bits; ++amount) { EXPECT_EQ(x >> amount, IntN(i >> amount)); EXPECT_EQ(x << amount, IntN(i << amount)); EXPECT_EQ((a = x, a >>= amount), IntN(i >> amount)); EXPECT_EQ((a = x, a <<= amount), IntN(i << amount)); } } } TYPED_TEST(IntNTest, ToString) { using IntN = TypeParam; for (int i = static_cast(std::numeric_limits::min()); i <= static_cast(std::numeric_limits::max()); ++i) { IntN x = CreateIntNWithRandomHighBits(i); std::stringstream ss; ss << x; EXPECT_EQ(ss.str(), std::to_string(i)); EXPECT_EQ(x.ToString(), std::to_string(i)); } } struct CustomInt { constexpr CustomInt() : x(0) {} constexpr CustomInt(int x) : x(x) {} // NOLINTNEXTLINE(google-explicit-constructor) constexpr operator int() const { return x; } constexpr bool operator==(const CustomInt& other) const { return x == other.x; } private: int x; }; #define GEN_DEST_TYPES(Type) \ std::pair, std::pair, std::pair, \ std::pair, std::pair, \ std::pair, std::pair, \ std::pair, std::pair, std::pair, \ std::pair, std::pair, \ std::pair, std::pair, \ std::pair, std::pair #define GEN_TYPE_PAIRS() \ GEN_DEST_TYPES(int1), GEN_DEST_TYPES(uint1), GEN_DEST_TYPES(int2), \ GEN_DEST_TYPES(uint2), GEN_DEST_TYPES(int4), GEN_DEST_TYPES(uint4) using IntNCastTypePairs = ::testing::Types; template class IntNCastTest : public ::testing::Test {}; // Helper utility for prettier 
test names. struct IntNCastTestParamNames { template static std::string GetName(int idx) { using first_type = typename TypeParam::first_type; using second_type = typename TypeParam::second_type; return ::testing::internal::GetTypeName() + "_" + ::testing::internal::GetTypeName(); } }; TYPED_TEST_SUITE(IntNCastTest, IntNCastTypePairs, IntNCastTestParamNames); TYPED_TEST(IntNCastTest, CastThroughInt) { using IntN = typename TypeParam::first_type; using DestType = typename TypeParam::second_type; for (int i = 0; i < (1 << IntN::bits); ++i) { IntN x = CreateIntNWithRandomHighBits(i); DestType dest = static_cast(x); DestType expected = static_cast(static_cast(x)); EXPECT_EQ(dest, expected); } } TYPED_TEST(IntNCastTest, DeviceCast) { using IntN = typename TypeParam::first_type; using DestType = typename TypeParam::second_type; #if defined(EIGEN_USE_GPU) Eigen::GpuStreamDevice stream; Eigen::GpuDevice device(&stream); #elif defined(EIGEN_USE_THREADS) constexpr int kThreads = 4; Eigen::ThreadPool tp(kThreads); Eigen::ThreadPoolDevice device(&tp, kThreads); #else Eigen::DefaultDevice device; #endif const int kNumElems = 256; // Allocate device buffers and create device tensors. IntN* src_device_buffer = (IntN*)device.allocate(kNumElems * sizeof(IntN)); DestType* dst_device_buffer = (DestType*)device.allocate(kNumElems * sizeof(DestType)); Eigen::TensorMap, Eigen::Aligned> src_device( src_device_buffer, kNumElems); Eigen::TensorMap, Eigen::Aligned> dst_device( dst_device_buffer, kNumElems); // Allocate host buffers and initialize src memory. Eigen::Tensor src_cpu(kNumElems); Eigen::Tensor dst_cpu(kNumElems); for (int i = 0; i < kNumElems; ++i) { src_cpu(i) = Eigen::numext::bit_cast(static_cast(i)); } // Transfer data to device, perform a cast to DestType, then transfer result // back to host. 
device.memcpyHostToDevice(src_device_buffer, src_cpu.data(), kNumElems * sizeof(IntN)); dst_device.device(device) = src_device.template cast(); device.memcpyDeviceToHost(dst_cpu.data(), dst_device_buffer, kNumElems * sizeof(DestType)); device.synchronize(); for (int i = 0; i < kNumElems; ++i) { DestType expected = static_cast(src_cpu(i)); EXPECT_EQ(dst_cpu(i), expected); } // Cast back from DestType to IntN. // First clear out the device src buffer, since that will be the destination. src_cpu.setZero(); device.memcpyHostToDevice(src_device_buffer, src_cpu.data(), kNumElems * sizeof(IntN)); src_device.device(device) = dst_device.template cast(); device.memcpyDeviceToHost(src_cpu.data(), src_device_buffer, kNumElems * sizeof(IntN)); device.synchronize(); for (int i = 0; i < kNumElems; ++i) { IntN expected = static_cast(dst_cpu(i)); EXPECT_EQ(src_cpu(i), expected); } // Clean up. device.deallocate(src_device_buffer); device.deallocate(dst_device_buffer); device.synchronize(); } } // namespace } // namespace ml_dtypes jax-ml-ml_dtypes-882eb0f/ml_dtypes/tests/intn_test.py000066400000000000000000000327141510671665600230670ustar00rootroot00000000000000# Copyright 2022 The ml_dtypes Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Test cases for int4 types.""" # pylint: disable=g-complex-comprehension import contextlib import copy import operator import pickle import warnings from absl.testing import absltest from absl.testing import parameterized import ml_dtypes from multi_thread_utils import multi_threaded import numpy as np int2 = ml_dtypes.int2 int4 = ml_dtypes.int4 uint2 = ml_dtypes.uint2 uint4 = ml_dtypes.uint4 INTN_TYPES = [int2, int4, uint2, uint4] VALUES = { int2: list(range(-2, 2)), int4: list(range(-8, 8)), uint2: list(range(0, 4)), uint4: list(range(0, 16)), } FLOAT_TYPES = [ ml_dtypes.bfloat16, ml_dtypes.float4_e2m1fn, ml_dtypes.float6_e2m3fn, ml_dtypes.float6_e3m2fn, ml_dtypes.float8_e3m4, ml_dtypes.float8_e4m3, ml_dtypes.float8_e4m3b11fnuz, ml_dtypes.float8_e4m3fn, ml_dtypes.float8_e4m3fnuz, ml_dtypes.float8_e5m2, ml_dtypes.float8_e5m2fnuz, # No casts to e8m0fnu for now. # ml_dtypes.float8_e8m0fnu, ] @contextlib.contextmanager def ignore_warning(**kw): with warnings.catch_warnings(): warnings.filterwarnings("ignore", **kw) yield # Tests for the Python scalar type @multi_threaded(num_workers=3) class ScalarTest(parameterized.TestCase): @parameterized.product(scalar_type=INTN_TYPES) def testModuleName(self, scalar_type): self.assertEqual(scalar_type.__module__, "ml_dtypes") @parameterized.product(scalar_type=INTN_TYPES) def testPickleable(self, scalar_type): # https://github.com/jax-ml/jax/discussions/8505 x = np.arange(10, dtype=scalar_type) serialized = pickle.dumps(x) x_out = pickle.loads(serialized) self.assertEqual(x_out.dtype, x.dtype) np.testing.assert_array_equal(x_out.astype(int), x.astype(int)) @parameterized.product( scalar_type=INTN_TYPES, python_scalar=[int, float, np.float16, np.longdouble], ) def testRoundTripToPythonScalar(self, scalar_type, python_scalar): for v in VALUES[scalar_type]: self.assertEqual(v, scalar_type(v)) self.assertEqual(python_scalar(v), python_scalar(scalar_type(v))) self.assertEqual( scalar_type(v), 
scalar_type(python_scalar(scalar_type(v))) ) @parameterized.product(scalar_type=INTN_TYPES) def testRoundTripNumpyTypes(self, scalar_type): for dtype in [np.int8, np.int32]: for f in VALUES[scalar_type]: self.assertEqual(dtype(f), dtype(scalar_type(dtype(f)))) self.assertEqual(int(dtype(f)), int(scalar_type(dtype(f)))) self.assertEqual(dtype(f), dtype(scalar_type(np.array(f, dtype)))) np.testing.assert_equal( dtype(np.array(VALUES[scalar_type], scalar_type)), np.array(VALUES[scalar_type], dtype), ) @parameterized.product(scalar_type=INTN_TYPES) def testStr(self, scalar_type): for value in VALUES[scalar_type]: self.assertEqual(str(value), str(scalar_type(value))) @parameterized.product(scalar_type=INTN_TYPES) def testRepr(self, scalar_type): for value in VALUES[scalar_type]: self.assertEqual(str(value), str(scalar_type(value))) @parameterized.product(scalar_type=INTN_TYPES) def testItem(self, scalar_type): self.assertIsInstance(scalar_type(1).item(), int) self.assertEqual(scalar_type(1).item(), 1) @parameterized.product(scalar_type=INTN_TYPES) def testHash(self, scalar_type): for v in VALUES[scalar_type]: self.assertEqual(hash(v), hash(scalar_type(v)), msg=v) @parameterized.product( scalar_type=INTN_TYPES, op=[ operator.le, operator.lt, operator.eq, operator.ne, operator.ge, operator.gt, ], ) def testComparison(self, scalar_type, op): for v in VALUES[scalar_type]: for w in VALUES[scalar_type]: result = op(scalar_type(v), scalar_type(w)) self.assertEqual(op(v, w), result) self.assertIsInstance(result, np.bool_) @parameterized.product( scalar_type=INTN_TYPES, op=[ operator.neg, operator.pos, ], ) def testUnop(self, scalar_type, op): for v in VALUES[scalar_type]: out = op(scalar_type(v)) self.assertIsInstance(out, scalar_type) self.assertEqual(scalar_type(op(v)), out, msg=v) @parameterized.product( scalar_type=INTN_TYPES, op=[ operator.add, operator.sub, operator.mul, operator.floordiv, operator.mod, ], ) def testBinop(self, scalar_type, op): for v in 
VALUES[scalar_type]: for w in VALUES[scalar_type]: if w == 0 and op in [operator.floordiv, operator.mod]: with self.assertRaises(ZeroDivisionError): op(scalar_type(v), scalar_type(w)) else: out = op(scalar_type(v), scalar_type(w)) self.assertIsInstance(out, scalar_type) self.assertEqual(scalar_type(op(v, w)), out, msg=(v, w)) CAST_DTYPES = [ np.float16, np.float32, np.float64, np.int8, np.int16, np.int32, np.int64, np.complex64, np.complex128, np.uint8, np.uint16, np.uint32, np.uint64, np.intc, np.int_, np.longlong, np.uintc, np.ulonglong, ] + INTN_TYPES @parameterized.product(a=INTN_TYPES, b=CAST_DTYPES) def testCanCast(self, a, b): allowed_casts = [ (np.bool_, int2), (np.bool_, int4), (int2, int2), (int2, np.int8), (int2, np.int16), (int2, np.int32), (int2, np.int64), (int2, np.float16), (int2, np.float32), (int2, np.float64), (int2, np.complex64), (int2, np.complex128), (int2, int4), (int4, int4), (int4, np.int8), (int4, np.int16), (int4, np.int32), (int4, np.int64), (int4, np.float16), (int4, np.float32), (int4, np.float64), (int4, np.complex64), (int4, np.complex128), (np.bool_, uint2), (np.bool_, uint4), (uint2, uint2), (uint2, np.int8), (uint2, np.int16), (uint2, np.int32), (uint2, np.int64), (uint2, np.uint8), (uint2, np.uint16), (uint2, np.uint32), (uint2, np.uint64), (uint2, np.float16), (uint2, np.float32), (uint2, np.float64), (uint2, np.complex64), (uint2, np.complex128), (uint2, uint4), (uint4, uint4), (uint4, np.int8), (uint4, np.int16), (uint4, np.int32), (uint4, np.int64), (uint4, np.uint8), (uint4, np.uint16), (uint4, np.uint32), (uint4, np.uint64), (uint4, np.float16), (uint4, np.float32), (uint4, np.float64), (uint4, np.complex64), (uint4, np.complex128), ] allowed_casts += [(a, b) for a in INTN_TYPES for b in FLOAT_TYPES] self.assertEqual( ((a, b) in allowed_casts), np.can_cast(a, b, casting="safe") ) @parameterized.product(scalar_type=INTN_TYPES) def testIssubdtype(self, scalar_type): # In the future, we may want to make these more specific 
(e.g. use # np.number or np.integer instead of np.generic) by changing the # base in RegisterIntNDtype. self.assertTrue(np.issubdtype(scalar_type, np.generic)) self.assertTrue(np.issubdtype(np.dtype(scalar_type), np.generic)) @parameterized.product(scalar_type=INTN_TYPES) def testCastToDtype(self, scalar_type): name = scalar_type.__name__ dt = np.dtype(scalar_type) self.assertIs(dt.type, scalar_type) self.assertEqual(dt.name, name) self.assertEqual(repr(dt), f"dtype({name})") @parameterized.product(scalar_type=INTN_TYPES) def testCastFailure(self, scalar_type): with self.assertRaises(ValueError): scalar_type(np.nan) with self.assertRaises(OverflowError): scalar_type(np.inf) with self.assertRaises(OverflowError): scalar_type(1e10) with self.assertRaises(ValueError): np.array(np.nan, dtype=scalar_type) with self.assertRaises(OverflowError): np.array(np.inf, dtype=scalar_type) with self.assertRaises(OverflowError): np.array(1e10, dtype=scalar_type) # But these shouldn't raise exceptions. np.array(np.nan).astype(scalar_type) np.array(np.inf).astype(scalar_type) with np.errstate(invalid="ignore"): np.array(1e10).astype(scalar_type) # Tests for the Python scalar type @multi_threaded(num_workers=3, skip_tests=["testBinaryUfuncs"]) class ArrayTest(parameterized.TestCase): @parameterized.product(scalar_type=INTN_TYPES) def testDtype(self, scalar_type): self.assertEqual(scalar_type, np.dtype(scalar_type)) @parameterized.product(scalar_type=INTN_TYPES) def testHash(self, scalar_type): h = hash(np.dtype(scalar_type)) self.assertEqual(h, hash(np.dtype(scalar_type.dtype))) self.assertEqual(h, hash(np.dtype(scalar_type.__name__))) @parameterized.product(scalar_type=INTN_TYPES) def testDeepCopyDoesNotAlterHash(self, scalar_type): # For context, see https://github.com/jax-ml/jax/issues/4651. If the hash # value of the type descriptor is not initialized correctly, a deep copy # can change the type hash. 
dtype = np.dtype(scalar_type) h = hash(dtype) _ = copy.deepcopy(dtype) self.assertEqual(h, hash(dtype)) @parameterized.product(scalar_type=INTN_TYPES) def testArray(self, scalar_type): if scalar_type == int2: x = np.array([[-2, 1, 0, 1]], dtype=scalar_type) self.assertEqual("[[-2 1 0 1]]", str(x)) else: x = np.array([[1, 2, 3]], dtype=scalar_type) self.assertEqual("[[1 2 3]]", str(x)) self.assertEqual(scalar_type, x.dtype) np.testing.assert_array_equal(x, x) self.assertTrue((x == x).all()) # pylint: disable=comparison-with-itself @parameterized.product( scalar_type=INTN_TYPES, ufunc=[np.nonzero, np.logical_not, np.argmax, np.argmin], ) def testUnaryPredicateUfunc(self, scalar_type, ufunc): x = np.array(VALUES[scalar_type]) y = np.array(VALUES[scalar_type], dtype=scalar_type) # Compute `ufunc(y)` first so we don't get lucky by reusing memory # initialized by `ufunc(x)`. y_result = ufunc(y) x_result = ufunc(x) np.testing.assert_array_equal(x_result, y_result) @parameterized.product( scalar_type=INTN_TYPES, ufunc=[ np.less, np.less_equal, np.greater, np.greater_equal, np.equal, np.not_equal, np.logical_and, np.logical_or, np.logical_xor, ], ) def testPredicateUfuncs(self, scalar_type, ufunc): x = np.array(VALUES[scalar_type]) y = np.array(VALUES[scalar_type], dtype=scalar_type) np.testing.assert_array_equal( ufunc(x[:, None], x[None, :]), ufunc(y[:, None], y[None, :]), ) @parameterized.product( scalar_type=INTN_TYPES, dtype=[ np.float16, np.float32, np.float64, np.longdouble, np.int8, np.int16, np.int32, np.int64, np.complex64, np.complex128, np.clongdouble, np.uint8, np.uint16, np.uint32, np.uint64, np.intc, np.int_, np.longlong, np.uintc, np.ulonglong, ], ) def testCasts(self, scalar_type, dtype): x_orig = np.array(VALUES[scalar_type]) x = np.array(VALUES[scalar_type]).astype(dtype) x = np.where(x == x_orig, x, np.zeros_like(x)) y = x.astype(scalar_type) z = y.astype(dtype) self.assertTrue(np.all(x == y), msg=(x, y)) self.assertEqual(scalar_type, y.dtype) 
self.assertTrue(np.all(x == z)) self.assertEqual(dtype, z.dtype) # TODO(phawkins): ideally we would also allow unsafe casts between custom # types, but I'm unable to figure out how to convince NumPy to treat custom # casts as unsafe. @parameterized.product( types=[(int2, int4), (uint2, uint4)] + [ (a, b) for a in INTN_TYPES for b in FLOAT_TYPES if (a, b) not in [ (int4, ml_dtypes.float6_e2m3fn), (uint4, ml_dtypes.float6_e2m3fn), ] ] ) def testCastBetweenCustomTypes(self, types): a, b = types x = np.array(VALUES[a], dtype=a) y = x.astype(b) np.testing.assert_array_equal( np.array(VALUES[a], dtype=b).astype(np.int32), y.astype(np.int32) ) @parameterized.product( scalar_type=INTN_TYPES, ufunc=[ np.add, np.subtract, np.multiply, np.floor_divide, np.remainder, ], ) @ignore_warning(category=RuntimeWarning, message="divide by zero encountered") def testBinaryUfuncs(self, scalar_type, ufunc): x = np.array(VALUES[scalar_type]) y = np.array(VALUES[scalar_type], dtype=scalar_type) np.testing.assert_array_equal( ufunc(x[:, None], x[None, :]).astype(scalar_type), ufunc(y[:, None], y[None, :]), ) if __name__ == "__main__": absltest.main() jax-ml-ml_dtypes-882eb0f/ml_dtypes/tests/metadata_test.py000066400000000000000000000022731510671665600236740ustar00rootroot00000000000000# Copyright 2022 The ml_dtypes Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from importlib import metadata from absl.testing import absltest import ml_dtypes from multi_thread_utils import multi_threaded @multi_threaded(num_workers=3) class CustomFloatTest(absltest.TestCase): def test_version_matches_package_metadata(self): try: ml_dtypes_metadata = metadata.metadata("ml_dtypes") except ImportError as err: raise absltest.SkipTest("Package metadata not found") from err metadata_version = ml_dtypes_metadata["version"] package_version = ml_dtypes.__version__ self.assertEqual(metadata_version, package_version) if __name__ == "__main__": absltest.main() jax-ml-ml_dtypes-882eb0f/ml_dtypes/tests/multi_thread_utils.py000066400000000000000000000033671510671665600247630ustar00rootroot00000000000000# Copyright 2024 The ml_dtypes Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Utilities for multi-threaded tests.""" import concurrent.futures import functools from typing import Optional def multi_threaded(*, num_workers: int, skip_tests: Optional[list[str]] = None): """Decorator that runs a test in a multi-threaded environment.""" def decorator(test_cls): for name, test_fn in test_cls.__dict__.copy().items(): if not (name.startswith("test") and callable(test_fn)): continue if skip_tests is not None: if any(test_name in name for test_name in skip_tests): continue @functools.wraps(test_fn) # pylint: disable=cell-var-from-loop def multi_threaded_test_fn(*args, __test_fn__=test_fn, **kwargs): with concurrent.futures.ThreadPoolExecutor( max_workers=num_workers ) as executor: futures = [] for _ in range(num_workers): futures.append(executor.submit(__test_fn__, *args, **kwargs)) # We should call future.result() to re-raise an exception if test has # failed list(f.result() for f in futures) setattr(test_cls, f"{name}_multi_threaded", multi_threaded_test_fn) return test_cls return decorator jax-ml-ml_dtypes-882eb0f/ml_dtypes/tests/mxfloat_test.cc000066400000000000000000000301531510671665600235210ustar00rootroot00000000000000/* Copyright 2024 The ml_dtypes Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ #include "ml_dtypes/include/mxfloat.h" #include namespace ml_dtypes { namespace { TEST(FloatMXe2m3Test, NumericLimits) { using limits = std::numeric_limits; EXPECT_EQ(static_cast(limits::min()), 1.0); EXPECT_EQ(static_cast(limits::max()), 7.5); EXPECT_EQ(static_cast(limits::lowest()), -7.5); EXPECT_EQ(static_cast(limits::epsilon()), 0.125); EXPECT_EQ(static_cast(limits::round_error()), 0.25); EXPECT_EQ(static_cast(limits::denorm_min()), 0.125); EXPECT_EQ(limits::digits, 4); EXPECT_EQ(limits::digits10, 0); EXPECT_EQ(limits::max_digits10, 3); EXPECT_EQ(limits::min_exponent, 1); EXPECT_EQ(limits::min_exponent10, 0); EXPECT_EQ(limits::max_exponent, 3); EXPECT_EQ(limits::max_exponent10, 0); EXPECT_EQ(limits::is_iec559, false); EXPECT_EQ(limits::has_infinity, false); EXPECT_EQ(limits::has_quiet_NaN, false); EXPECT_EQ(limits::has_signaling_NaN, false); } TEST(FloatMXe3m2Test, NumericLimits) { using limits = std::numeric_limits; EXPECT_EQ(static_cast(limits::min()), 0.25); EXPECT_EQ(static_cast(limits::max()), 28.0); EXPECT_EQ(static_cast(limits::lowest()), -28.0); EXPECT_EQ(static_cast(limits::epsilon()), 0.25); EXPECT_EQ(static_cast(limits::round_error()), 1.0); EXPECT_EQ(static_cast(limits::denorm_min()), 0.0625); EXPECT_EQ(limits::digits, 3); EXPECT_EQ(limits::digits10, 0); EXPECT_EQ(limits::max_digits10, 2); EXPECT_EQ(limits::min_exponent, -1); EXPECT_EQ(limits::min_exponent10, 0); EXPECT_EQ(limits::max_exponent, 5); EXPECT_EQ(limits::max_exponent10, 1); EXPECT_EQ(limits::is_iec559, false); EXPECT_EQ(limits::has_infinity, false); EXPECT_EQ(limits::has_quiet_NaN, false); EXPECT_EQ(limits::has_signaling_NaN, false); } TEST(Float4e2m1Test, NumericLimits) { using limits = std::numeric_limits; EXPECT_EQ(static_cast(limits::min()), 1.0); EXPECT_EQ(static_cast(limits::max()), 6.0); EXPECT_EQ(static_cast(limits::lowest()), -6.0); EXPECT_EQ(static_cast(limits::epsilon()), 0.5); 
EXPECT_EQ(static_cast(limits::round_error()), 1.0); EXPECT_EQ(static_cast(limits::denorm_min()), 0.5); EXPECT_EQ(limits::digits, 2); EXPECT_EQ(limits::digits10, 0); EXPECT_EQ(limits::max_digits10, 2); EXPECT_EQ(limits::min_exponent, 1); EXPECT_EQ(limits::min_exponent10, 0); EXPECT_EQ(limits::max_exponent, 3); EXPECT_EQ(limits::max_exponent10, 0); EXPECT_EQ(limits::is_iec559, false); EXPECT_EQ(limits::has_infinity, false); EXPECT_EQ(limits::has_quiet_NaN, false); EXPECT_EQ(limits::has_signaling_NaN, false); } template constexpr int NumValues() { return 1 << T::kBits; } template class FloatMXTest : public ::testing::Test {}; struct FloatMXTestNameGenerator { template static std::string GetName(int) { if constexpr (std::is_same_v) return "float6_e2m3fn"; if constexpr (std::is_same_v) return "float6_e3m2fn"; if constexpr (std::is_same_v) return "float4_e2m1fn"; } }; using FloatMXTypes = ::testing::Types; TYPED_TEST_SUITE(FloatMXTest, FloatMXTypes, FloatMXTestNameGenerator); TYPED_TEST(FloatMXTest, NoInfinity) { using FloatMX = TypeParam; EXPECT_EQ(static_cast(INFINITY), std::numeric_limits::max()); EXPECT_EQ(static_cast(-INFINITY), std::numeric_limits::lowest()); } TYPED_TEST(FloatMXTest, Negate) { using FloatMX = TypeParam; int sign_bit = 1 << (FloatMX::kBits - 1); for (int i = 0; i < sign_bit; ++i) { FloatMX pos = FloatMX::FromRep(i); FloatMX neg = FloatMX::FromRep(i | sign_bit); EXPECT_EQ((-pos).rep(), neg.rep()); EXPECT_EQ((-neg).rep(), pos.rep()); } } TYPED_TEST(FloatMXTest, Signbit) { using FloatMX = TypeParam; FloatMX one(1.0); EXPECT_EQ(Eigen::numext::signbit(one).rep(), 0x00); EXPECT_EQ(Eigen::numext::signbit(-one).rep(), 0xff); } TYPED_TEST(FloatMXTest, BitCasts) { using FloatMX = TypeParam; FloatMX x = FloatMX::FromRep(0x11); EXPECT_EQ(Eigen::numext::bit_cast(x), x.rep()); EXPECT_EQ(Eigen::numext::bit_cast(x.rep()), x); } TYPED_TEST(FloatMXTest, UpCasts) { using FloatMX = TypeParam; for (int i = 0; i < NumValues(); ++i) { FloatMX mx = FloatMX::FromRep(i); 
double f64 = static_cast(mx); float f32 = static_cast(mx); Eigen::bfloat16 bf16 = static_cast(mx); Eigen::half f16 = static_cast(mx); EXPECT_EQ(f64, f32) << i; EXPECT_EQ(f32, bf16) << i; EXPECT_EQ(bf16, f16) << i; } } TYPED_TEST(FloatMXTest, DownCasts) { using FloatMX = TypeParam; for (int i = 0; i < NumValues(); ++i) { float x = static_cast(FloatMX::FromRep(i)); FloatMX f64 = static_cast(static_cast(x)); FloatMX f32 = static_cast(static_cast(x)); FloatMX bf16 = static_cast(static_cast(x)); FloatMX f16 = static_cast(static_cast(x)); EXPECT_EQ(f64.rep(), i); EXPECT_EQ(f32.rep(), i); EXPECT_EQ(bf16.rep(), i); EXPECT_EQ(f16.rep(), i); } } TYPED_TEST(FloatMXTest, ConvertFromWithSaturation) { using FloatMX = TypeParam; FloatMX upper = FloatMX::template ConvertFrom( static_cast(std::numeric_limits::max()) * 2); EXPECT_EQ(upper, std::numeric_limits::max()); FloatMX lower = FloatMX::template ConvertFrom( static_cast(std::numeric_limits::lowest()) * 2); EXPECT_EQ(lower, std::numeric_limits::lowest()); } TYPED_TEST(FloatMXTest, ConvertFromWithTruncation) { using FloatMX = TypeParam; // Truncation and rounding of a number ever-so-slightly less than 2. float less_than_two = Eigen::numext::bit_cast(0x3FFFFFFF); FloatMX truncated = FloatMX::template ConvertFrom( less_than_two); EXPECT_LT(static_cast(truncated), 2); FloatMX rounded = FloatMX::template ConvertFrom( less_than_two); EXPECT_EQ(static_cast(rounded), 2); // Truncation and rounding of a subnormal. 
int digits = std::numeric_limits::digits; for (int i = 1; i < (1 << (digits - 1)); ++i) { float less_than_subnorm = std::nexttoward(static_cast(FloatMX::FromRep(i)), 0); FloatMX truncated_subnorm = FloatMX::template ConvertFrom( less_than_subnorm); EXPECT_EQ(truncated_subnorm.rep(), i - 1); FloatMX rounded_subnorm = FloatMX::template ConvertFrom( less_than_subnorm); EXPECT_EQ(rounded_subnorm.rep(), i); } } TYPED_TEST(FloatMXTest, ConvertFromRoundToNearest) { using FloatMX = TypeParam; // Try all pairs of values and check the middle point (which should be exactly // representable as a float), as well as adjacent values. for (int i = 1; i < NumValues(); ++i) { FloatMX left = FloatMX::FromRep(i - 1); FloatMX right = FloatMX::FromRep(i); if (!right) continue; // Skip jump to negative zero. float l = static_cast(left); float r = static_cast(right); float m = (l + r) / 2; float m_minus_eps = std::nexttoward(m, l); float m_plus_eps = std::nexttoward(m, r); EXPECT_EQ(static_cast(m).rep(), i & 1 ? 
left.rep() : right.rep()); EXPECT_EQ(static_cast(m_minus_eps).rep(), left.rep()); EXPECT_EQ(static_cast(m_plus_eps).rep(), right.rep()); } } TYPED_TEST(FloatMXTest, CompareOperator) { using FloatMX = TypeParam; for (int i = 0; i < NumValues(); ++i) { FloatMX a = FloatMX::FromRep(i); for (int j = 0; j < NumValues(); ++j) { FloatMX b = FloatMX::FromRep(j); EXPECT_EQ(a == b, float{a} == float{b}); EXPECT_EQ(a != b, float{a} != float{b}); EXPECT_EQ(a < b, float{a} < float{b}); EXPECT_EQ(a <= b, float{a} <= float{b}); EXPECT_EQ(a > b, float{a} > float{b}); EXPECT_EQ(a >= b, float{a} >= float{b}); } } } #define GEN_FLOAT_TYPE_PAIRS(Type) \ std::pair, std::pair, \ std::pair, std::pair, \ std::pair, std::pair, \ std::pair, std::pair, \ std::pair, std::pair, \ std::pair #define GEN_TEST_TYPE_PAIRS() \ GEN_FLOAT_TYPE_PAIRS(float6_e2m3fn), GEN_FLOAT_TYPE_PAIRS(float6_e3m2fn), \ GEN_FLOAT_TYPE_PAIRS(float4_e2m1fn), \ std::pair, \ std::pair, \ std::pair template class FloatMXCastTest : public ::testing::Test {}; struct FloatMXCastTestNameGenerator { template static std::string GetName(int) { std::string first_name = ::testing::internal::GetTypeName(); std::string second_name = ::testing::internal::GetTypeName(); return first_name + "_" + second_name; } }; using FloatMXCastTypePairs = ::testing::Types; TYPED_TEST_SUITE(FloatMXCastTest, FloatMXCastTypePairs, FloatMXCastTestNameGenerator); TYPED_TEST(FloatMXCastTest, FromFloatMX) { using FloatMX = typename TypeParam::first_type; using DestType = typename TypeParam::second_type; for (int i = 0; i < NumValues(); ++i) { FloatMX mx = FloatMX::FromRep(i); DestType converted = static_cast(mx); DestType expected = static_cast(static_cast(mx)); if (Eigen::numext::isnan(expected)) { EXPECT_TRUE(Eigen::numext::isnan(converted)); } else { EXPECT_EQ(converted, expected); } } } TYPED_TEST(FloatMXCastTest, ToFloatMX) { using FloatMX = typename TypeParam::first_type; using SrcType = typename TypeParam::second_type; using SrcTraits = typename 
float8_internal::Traits; // For float8, iterate over all possible values. // For other floating point types, discard lower mantissa bits that do not // participate in rounding calculation to keep the test size reasonable. constexpr bool is_fp8 = sizeof(SrcType) == 1; int test_bits = SrcTraits::kBits, shift = 0; if (!is_fp8) { int e_bits = test_bits - std::numeric_limits::digits; int m_bits = std::numeric_limits::digits + 1; test_bits = 1 + e_bits + m_bits; shift = sizeof(SrcType) * CHAR_BIT - test_bits; } using BitsType = typename SrcTraits::BitsType; for (int i = 0; i < (1 << test_bits); ++i) { BitsType value = static_cast(i) << shift; SrcType fp = Eigen::numext::bit_cast(value); FloatMX converted = static_cast(fp); FloatMX expected = static_cast(static_cast(fp)); EXPECT_EQ(converted, expected); } } } // namespace } // namespace ml_dtypes jax-ml-ml_dtypes-882eb0f/pyproject.toml000066400000000000000000000034531510671665600202560ustar00rootroot00000000000000[project] name = "ml_dtypes" dynamic = ["version"] # Load from ml_dtypes.__version__. description = "ml_dtypes is a stand-alone implementation of several NumPy dtype extensions used in machine learning." readme = "README.md" requires-python = ">=3.9" license = "Apache-2.0" license-files = ["LICENSE", "LICENSE.eigen"] authors = [{name = "ml_dtypes authors", email="ml_dtypes@google.com"}] classifiers = [ "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", "Intended Audience :: Science/Research", ] keywords = [] # pip dependencies of the project dependencies = [ # Ensure numpy release supports Python version. 
"numpy>=1.21", "numpy>=1.21.2; python_version>='3.10'", "numpy>=1.23.3; python_version>='3.11'", "numpy>=1.26.0; python_version>='3.12'", "numpy>=2.1.0; python_version>='3.13'", ] [project.urls] homepage = "https://github.com/jax-ml/ml_dtypes" repository = "https://github.com/jax-ml/ml_dtypes" # Other: `documentation`, `changelog` [project.optional-dependencies] # Development deps (unittest, linting, formating,...) # Installed through `pip install .[dev]` dev = [ "absl-py", "pytest", "pytest-xdist", "pylint>=2.6.0", "pyink", ] [tool.pyink] # Formatting configuration to follow Google style-guide line-length = 80 preview = true pyink-indentation = 2 pyink-use-majority-quotes = true [build-system] requires = [ # We build against the most recent supported NumPy 2.0 release; # see https://github.com/numpy/numpy/issues/27265 "numpy~=2.0", "setuptools~=80.9.0", ] build-backend = "setuptools.build_meta" [tool.setuptools] packages = ["ml_dtypes"] include-package-data = false [tool.setuptools.dynamic] version = {attr = "ml_dtypes.__version__"} [tool.setuptools.package-data] ml_dtypes = ["py.typed"] jax-ml-ml_dtypes-882eb0f/pytest.ini000066400000000000000000000001401510671665600173610ustar00rootroot00000000000000[pytest] filterwarnings = error ignore:numpy.core._multiarray_umat.*:DeprecationWarning jax-ml-ml_dtypes-882eb0f/setup.py000066400000000000000000000043201510671665600170460ustar00rootroot00000000000000# Copyright 2022 The ml_dtypes Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. """Setuptool-based build for ml_dtypes.""" import fnmatch import platform import numpy as np from setuptools import Extension from setuptools import setup from setuptools.command.build_py import build_py as build_py_orig if platform.system() == "Windows": COMPILE_ARGS = [ "/std:c++17", "/DEIGEN_MPL2_ONLY", "/EHsc", "/bigobj", ] else: COMPILE_ARGS = [ "-std=c++17", "-DEIGEN_MPL2_ONLY", "-fvisibility=hidden", # -ftrapping-math is necessary because NumPy looks at floating point # exception state to determine whether to emit, e.g., invalid value # warnings. Without this setting, on Mac ARM we see spurious "invalid # value" warnings when running the tests. "-ftrapping-math", ] exclude = ["third_party*"] class build_py(build_py_orig): # pylint: disable=invalid-name def find_package_modules(self, package, package_dir): modules = super().find_package_modules(package, package_dir) return [ # pylint: disable=g-complex-comprehension (pkg, mod, file) for (pkg, mod, file) in modules if not any( fnmatch.fnmatchcase(pkg + "." + mod, pat=pattern) for pattern in exclude ) ] setup( ext_modules=[ Extension( "ml_dtypes._ml_dtypes_ext", [ "ml_dtypes/_src/dtypes.cc", "ml_dtypes/_src/numpy.cc", ], include_dirs=[ "third_party/eigen", ".", np.get_include(), ], extra_compile_args=COMPILE_ARGS, ) ], cmdclass={"build_py": build_py}, ) jax-ml-ml_dtypes-882eb0f/third_party/000077500000000000000000000000001510671665600176665ustar00rootroot00000000000000jax-ml-ml_dtypes-882eb0f/third_party/eigen/000077500000000000000000000000001510671665600207555ustar00rootroot00000000000000