pax_global_header00006660000000000000000000000064141037436540014521gustar00rootroot0000000000000052 comment=dc8b4878d21f96c205aae0360f0a08500057e000 eagerpy-0.30.0/000077500000000000000000000000001410374365400132355ustar00rootroot00000000000000eagerpy-0.30.0/.github/000077500000000000000000000000001410374365400145755ustar00rootroot00000000000000eagerpy-0.30.0/.github/FUNDING.yml000066400000000000000000000011331410374365400164100ustar00rootroot00000000000000# These are supported funding model platforms # github: [jonasrauber] patreon: # Replace with a single Patreon username open_collective: # Replace with a single Open Collective username ko_fi: # Replace with a single Ko-fi username tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry liberapay: # Replace with a single Liberapay username issuehunt: # Replace with a single IssueHunt username otechie: # Replace with a single Otechie username custom: ["paypal.me/jonasrauber"] eagerpy-0.30.0/.github/workflows/000077500000000000000000000000001410374365400166325ustar00rootroot00000000000000eagerpy-0.30.0/.github/workflows/docs.yml000066400000000000000000000012761410374365400203130ustar00rootroot00000000000000name: Docs on: push: branches: - master jobs: vuepress: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - name: Install vuepress run: | sudo apt update sudo apt install yarn sudo yarn global add vuepress - name: Build run: | sudo vuepress build working-directory: ./docs - name: Push env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: | sudo git init sudo git add -A sudo git commit -m 'deploy' sudo git push -f https://x-access-token:${GITHUB_TOKEN}@github.com/jonasrauber/eagerpy.git master:gh-pages working-directory: ./docs/.vuepress/dist eagerpy-0.30.0/.github/workflows/pypi.yml000066400000000000000000000011321410374365400203330ustar00rootroot00000000000000name: PyPI on: release: types: [created] jobs: deploy: runs-on: ubuntu-latest steps: - uses: actions/checkout@v1 - name: Set up Python uses: actions/setup-python@v1 with: python-version: '3.6' - name: Install dependencies run: | python -m pip install --upgrade pip pip install setuptools wheel twine - name: Build and publish env: TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} run: | python setup.py sdist bdist_wheel twine upload dist/* eagerpy-0.30.0/.github/workflows/tests.yml000066400000000000000000000045601410374365400205240ustar00rootroot00000000000000name: Tests on: push: branches: - master pull_request: jobs: build: runs-on: ubuntu-latest strategy: max-parallel: 4 matrix: python-version: [3.6, 3.7] steps: - uses: actions/checkout@v1 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v1 with: python-version: ${{ matrix.python-version }} - uses: actions/cache@v1 with: path: ~/.cache/pip key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }} restore-keys: | ${{ runner.os }}-pip- - name: Install requirements-dev.txt run: | python -m pip install --upgrade pip setuptools pip install -r requirements-dev.txt - name: flake8 run: | flake8 . --count --show-source --statistics - name: black run: | black --check --verbose . - name: Install package run: | pip install -e . 
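      # The next step wraps pip in a retry-with-backoff shell helper (sleeping 0, 1, 2, 4, ... seconds between up to eight attempts), presumably to tolerate transient network failures while downloading the framework requirements.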
- name: Install requirements.txt run: | function retry-with-backoff() { for BACKOFF in 0 1 2 4 8 16 32 64; do sleep $BACKOFF if "$@"; then return 0 fi done return 1 } retry-with-backoff pip install -r requirements.txt - name: mypy (package) run: | mypy -p eagerpy - name: mypy (tests) run: | mypy tests/ - name: Test with pytest run: | pytest --cov-report term-missing --cov=eagerpy --verbose - name: Test with pytest (NumPy) run: | pytest --cov-report term-missing --cov=eagerpy --cov-append --verbose --backend numpy - name: Test with pytest (PyTorch) run: | pytest --cov-report term-missing --cov=eagerpy --cov-append --verbose --backend pytorch - name: Test with pytest (JAX) run: | pytest --cov-report term-missing --cov=eagerpy --cov-append --verbose --backend jax - name: Test with pytest (TensorFlow) run: | pytest --cov-report term-missing --cov=eagerpy --cov-append --verbose --backend tensorflow - name: Codecov continue-on-error: true env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} run: | codecov - name: Coveralls continue-on-error: true env: COVERALLS_REPO_TOKEN: ${{ secrets.COVERALLS_REPO_TOKEN }} run: | coveralls eagerpy-0.30.0/.gitignore000066400000000000000000000043041410374365400152260ustar00rootroot00000000000000.mypy_cache .pytest_cache/ # Created by https://www.gitignore.io/api/osx,vim,linux,python ### Linux ### *~ # temporary files which can be created if a process still has a handle open of a deleted file .fuse_hidden* # KDE directory preferences .directory # Linux trash folder which might appear on any partition or disk .Trash-* # .nfs files are created when an open file is removed but is still being accessed .nfs* ### OSX ### *.DS_Store .AppleDouble .LSOverride # Icon must end with two \r Icon # Thumbnails ._* # Files that might appear in the root of a volume .DocumentRevisions-V100 .fseventsd .Spotlight-V100 .TemporaryItems .Trashes .VolumeIcon.icns .com.apple.timemachine.donotpresent # Directories potentially created on remote AFP share .AppleDB .AppleDesktop Network Trash Folder Temporary Items .apdisk ### Python ### # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python env/ build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ *.egg-info/ .installed.cfg *.egg # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *,cover .hypothesis/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # pyenv .python-version # celery beat schedule file celerybeat-schedule # SageMath parsed files *.sage.py # dotenv .env # virtualenv .venv venv/ ENV/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site ### Vim ### # swap [._]*.s[a-v][a-z] [._]*.sw[a-p] [._]s[a-v][a-z] [._]sw[a-p] # session Session.vim # temporary .netrwhist # auto-generated tag files tags # PyCharm .idea/ # Visual Studio Code .vscode/ # End of https://www.gitignore.io/api/osx,vim,linux,python eagerpy-0.30.0/.pre-commit-config.yaml000066400000000000000000000003001410374365400175070ustar00rootroot00000000000000repos: - repo: https://github.com/ambv/black rev: 19.10b0 hooks: - id: black language_version: python3.6 - repo: https://gitlab.com/pycqa/flake8 rev: 3.7.9 hooks: - id: flake8 eagerpy-0.30.0/LICENSE000066400000000000000000000020551410374365400142440ustar00rootroot00000000000000MIT License Copyright (c) 2020 Jonas Rauber Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. eagerpy-0.30.0/MANIFEST.in000066400000000000000000000001011410374365400147630ustar00rootroot00000000000000include eagerpy/VERSION include eagerpy/py.typed include LICENSE eagerpy-0.30.0/Makefile000066400000000000000000000037171410374365400147050ustar00rootroot00000000000000.PHONY: test test: pytest --pdb --cov-report term-missing --cov=eagerpy --verbose pytest --pdb --cov-report term-missing --cov=eagerpy --cov-append --verbose --backend numpy pytest --pdb --cov-report term-missing --cov=eagerpy --cov-append --verbose --backend pytorch pytest --pdb --cov-report term-missing --cov=eagerpy --cov-append --verbose --backend jax pytest --pdb --cov-report term-missing --cov=eagerpy --cov-append --verbose --backend tensorflow pytest --pdb --cov-report term-missing --cov=eagerpy --cov-append --verbose --backend pytorch-gpu .PHONY: black black: black . .PHONY: blackcheck blackcheck: black --check . 
.PHONY: flake8 flake8: flake8 .PHONY: mypy mypy: mypy -p eagerpy mypy tests/ .PHONY: docs docs: pydocmd generate cd docs && vuepress build --temp /tmp/ .PHONY: installvuepress installvuepress: curl -sS https://dl.yarnpkg.com/debian/pubkey.gpg | sudo apt-key add - echo "deb https://dl.yarnpkg.com/debian/ stable main" | sudo tee /etc/apt/sources.list.d/yarn.list sudo apt update && sudo apt install yarn sudo yarn global add vuepress .PHONY: servedocs servedocs: cd docs/.vuepress/dist/ && python3 -m http.server 9999 .PHONY: devdocs devdocs: cd docs && vuepress dev --temp /tmp/ --port 9999 .PHONY: pushdocs pushdocs: cd docs/.vuepress/dist/ && git init && git add -A && git commit -m 'deploy' cd docs/.vuepress/dist/ && git push -f git@github.com:jonasrauber/eagerpy.git master:gh-pages .PHONY: install install: pip3 install -e . .PHONY: devsetup devsetup: pre-commit install .PHONY: build build: python3 setup.py sdist .PHONY: commit commit: git add eagerpy/VERSION git commit -m 'Version $(shell cat eagerpy/VERSION)' .PHONY: release release: build twine upload dist/eagerpy-$(shell cat eagerpy/VERSION).tar.gz .PHONY: pyre pyre: pyre --source-directory . check .PHONY: pytype pytype: pytype . .PHONY: pyright pyright: pyright . .PHONY: mypyreport mypyreport: -mypy . --html-report build python3 -m http.server 9999 eagerpy-0.30.0/README.rst000066400000000000000000000111351410374365400147250ustar00rootroot00000000000000.. raw:: html .. image:: https://badge.fury.io/py/eagerpy.svg :target: https://badge.fury.io/py/eagerpy .. image:: https://codecov.io/gh/jonasrauber/eagerpy/branch/master/graph/badge.svg :target: https://codecov.io/gh/jonasrauber/eagerpy .. image:: https://img.shields.io/badge/code%20style-black-000000.svg :target: https://github.com/ambv/black ================================================================================== EagerPy: Writing Code That Works Natively with PyTorch, TensorFlow, JAX, and NumPy ================================================================================== `EagerPy `_ is a **Python framework** that lets you write code that automatically works natively with `PyTorch `_, `TensorFlow `_, `JAX `_, and `NumPy `_. EagerPy is **also great when you work with just one framework** but prefer a clean and consistent API that is fully chainable, provides extensive type annotions and lets you write beautiful code. 🔥 Design goals ---------------- - **Native Performance**: EagerPy operations get directly translated into the corresponding native operations. - **Fully Chainable**: All functionality is available as methods on the tensor objects and as EagerPy functions. - **Type Checking**: Catch bugs before running your code thanks to EagerPy's extensive type annotations. 📖 Documentation ----------------- Learn more about EagerPy in the `documentation `_. 🚀 Quickstart -------------- .. code-block:: bash pip install eagerpy EagerPy requires Python 3.6 or newer. Besides that, all essential dependencies are automatically installed. To use it with PyTorch, TensorFlow, JAX, or NumPy, the respective framework needs to be installed separately. These frameworks are not declared as dependencies because not everyone wants to use and thus install all of them and because some of these packages have different builds for different architectures and `CUDA `_ versions. 🎉 Example ----------- .. 
code-block:: python import torch x = torch.tensor([1., 2., 3., 4., 5., 6.]) import tensorflow as tf x = tf.constant([1., 2., 3., 4., 5., 6.]) import jax.numpy as np x = np.array([1., 2., 3., 4., 5., 6.]) import numpy as np x = np.array([1., 2., 3., 4., 5., 6.]) # No matter which framwork you use, you can use the same code import eagerpy as ep # Just wrap a native tensor using EagerPy x = ep.astensor(x) # All of EagerPy's functionality is available as methods x = x.reshape((2, 3)) x.flatten(start=1).square().sum(axis=-1).sqrt() # or just: x.flatten(1).norms.l2() # and as functions (yes, gradients are also supported!) loss, grad = ep.value_and_grad(loss_fn, x) ep.clip(x + eps * grad, 0, 1) # You can even write functions that work transparently with # Pytorch tensors, TensorFlow tensors, JAX arrays, NumPy arrays def my_universal_function(a, b, c): # Convert all inputs to EagerPy tensors a, b, c = ep.astensors(a, b, c) # performs some computations result = (a + b * c).square() # and return a native tensor return result.raw 🗺 Use cases ------------ `Foolbox Native `_, the latest version of Foolbox, a popular adversarial attacks library, has been rewritten from scratch using EagerPy instead of NumPy to achieve native performance on models developed in PyTorch, TensorFlow and JAX, all with one code base. EagerPy is also used by other frameworks to reduce code duplication (e.g. `GUDHI `_) or to `compare the performance of different frameworks `_. 📄 Citation ------------ If you use EagerPy, please cite our `paper `_ using the this BibTex entry: .. code-block:: @article{rauber2020eagerpy, title={{EagerPy}: Writing Code That Works Natively with {PyTorch}, {TensorFlow}, {JAX}, and {NumPy}}, author={Rauber, Jonas and Bethge, Matthias and Brendel, Wieland}, journal={arXiv preprint arXiv:2008.04175}, year={2020}, url={https://eagerpy.jonasrauber.de}, } 🐍 Compatibility ----------------- We currently test with the following versions: * PyTorch 1.4.0 * TensorFlow 2.1.0 * JAX 0.1.57 * NumPy 1.18.1 eagerpy-0.30.0/docs/000077500000000000000000000000001410374365400141655ustar00rootroot00000000000000eagerpy-0.30.0/docs/.vuepress/000077500000000000000000000000001410374365400161175ustar00rootroot00000000000000eagerpy-0.30.0/docs/.vuepress/config.js000066400000000000000000000016001410374365400177170ustar00rootroot00000000000000module.exports = { title: 'EagerPy', description: 'A unified API for PyTorch, TensorFlow, JAX and NumPy', themeConfig: { nav: [ { text: 'Guide', link: '/guide/' }, { text: 'API', link: '/api/' }, { text: 'GitHub', link: 'https://github.com/jonasrauber/eagerpy' } ], sidebar: [ { title: 'Guide', collapsable: false, children: [ '/guide/', '/guide/getting-started', '/guide/converting', '/guide/generic-functions', '/guide/autodiff', '/guide/examples', '/guide/development', '/guide/citation', ], }, { title: 'API', collapsable: false, children: [ '/api/', ['/api/tensor', 'Tensor'], '/api/lib', '/api/norms', '/api/types', ], }, ], }, } eagerpy-0.30.0/docs/.vuepress/public/000077500000000000000000000000001410374365400173755ustar00rootroot00000000000000eagerpy-0.30.0/docs/.vuepress/public/CNAME000066400000000000000000000000261410374365400201410ustar00rootroot00000000000000eagerpy.jonasrauber.deeagerpy-0.30.0/docs/.vuepress/public/logo.png000066400000000000000000000241231410374365400210450ustar00rootroot00000000000000PNG  IHDRڴSaiCCPkCGColorSpaceDisplayP3(c``RI,(aa``+) rwRR` b >@% 0|/:%5I^bՋD0գd S JSl): ȞbC@$XMH3}VHHIBOGbCnₜJc%VhʢG`(*x%(30s 8, Ě30n~@\;b 'v$%)-rH 
@=iF`yF'{Vc``ww1Py!e5DeXIfMM*i&iTXtXML:com.adobe.xmp 528 182 1 pH$MIDATx Uřǫ\L< G@E3D͢1,c IN<9I'3E5!2C jD$nbhF$&Q1(ܚv~tUVwꫯ_8 (V N:lɂ;˝Lz|ӎ7wp1+FØΪ?{14(p֚oHg?SCWܡ`( +g1ί}198 -sd^#j^D9N#o?(W Ńtc |qhvE)l6b?8lcB4g&/exyy%Zgl54\ DN[g[n|dNb 7o>e!}I|s~h,.R.*$E?\Yͮ$>T):;llӕ<] v,L KNO;iAڡn< oGfgz;?`1GC8\4u=łMr? VᴽzwBu0=JD;-^>vaڣh8L̓z'.UzWhCӗ-r!_ g=ڹ-?c͞ϣss@Hc ݸ6LȠLv+y[qi|F-ٵ|=ơyw6?@}}} \p뭵Çm:Ϛ^}#@dPWF~ Ya%P ۺ0{DwۇUz{D`%Bġ7~1tȈt#{.T Nx]%c6SKqC~tUKZnƻz|xtHvfL {f/Gǧ@2z1:A~@rIA QoCvSΩ6jyУ`;֟yC%pUt8* p̰!ɳ_VJ}rWt_ۇ /hym5'&^5OEXGNz繛/mK ၚG; }wѾ]eI]y[2>fcʃ)WUD|킇CjifYg6PdIx;3Cm[Ѳ dJA=3GxS?$*rO}aҗ"IaO`٨j8Pvy9R#|xr@t{UICB2zt͊"%N$u!7.Cj=d(< B\y^CC܈PKV]?GMzt,D )feHNC+ g_#$jvo+4r0V XrZ)8Wa؁$!1PTTrC8T3!PS?K2gpDOf i>D zxH>[=dx)(P킇!sPL:30i9EO˅2vZ)=ra[Rё }6ƒoK&#:R`O!6C:gX,->@Ig[UB fxj)h,ZyBKy1B'=9ܛnOHNz w5BenBc̯B Tjp 'A7ij0 Y9t=aKrs ߴ W^'N_1zEN5ԕDL;'sZ* UP`,f0YMőNi< yمP{@KIUC+@K;{jTIG+zmXNDXؤùW묣}ʸ{ZIZOSw5ZWkyٹIiS<`,  &xhgl< <0vmV,oL&8~L63ߍP3(ǩ&Y27we Bv8?euaҮcy~;@SϠ}P A<dF̩tRƒQaۀ QA<fH|y1N ]eO]=^D-|e(H @DC 4% &l7Yx/lJW/; 0h^ 2Ca*b `1x+ࡔ2v<_q?cOĶIcVy ~凲VKBAǂͳϥPGrӃΕoj jV ʩC!1t C< a] dPATB 0AC  Z҇-TKR.|p+;!fc1wE'ࡴX(6i xHT~mm<P^3Dy}KjD@6(O\Rk*[eR`ʒ<(Һ@9r:Ex,F9)Cd x!U蚀`=<` 6!dV4g!&.&ݙয7}U8 @ aRlP(@<;}6X('êA^6k~j jDdɖI<Dࡄ0bόӾYF.b@C'8l1PWo<[^4,3b cKx!fhVb]#UM|~qdG+JXj@Ǝ9z Kx!fh^^xȁGfnx(a^9̽G-eKX,CpB|Qߡcm{<wΥc/c/Vtᚕe`={< -)еŧ|NyX%Yo.YjAࡸI*{jZkzZd.C V)qk]x(kn9^TkSwCyM畯an)ࡴ7UOQ':JZ}J)I)u q<.$ x(J}/y%y-J8Tܾ&;f&53.Dx=xioPd#Ǔ1!}__y;X< ' #smm] ݫw΂oZC ;Q .wλ P @(\U=]z!IENDB`eagerpy-0.30.0/docs/.vuepress/public/logo_small.png000066400000000000000000000152271410374365400222420ustar00rootroot00000000000000PNG  IHDR[aiCCPkCGColorSpaceDisplayP3(c``RI,(aa``+) rwRR` b >@% 0|/:%5I^bՋD0գd S JSl): ȞbC@$XMH3}VHHIBOGbCnₜJc%VhʢG`(*x%(30s 8, Ě30n~@\;b 'v$%)-rH @=iF`yF'{Vc``ww1Py!e5leXIfMM*>FiNHH[i pHYs  iTXtXML:com.adobe.xmp 1 528 182 /t iTIDATx ygvow2%va'RtH|gR.$ #r ˕hle;qdl%IT(:IDPr26<ݽytJ{WߏoYPq֓H8GW^& ƉCV6ؼGߑהI@N ]z9d,/҈sW;S20c| HO H%eixǷݤij/]%~鎷MӤFC7jU(,hD3MBL}% 弐~smkuٓ/%Qdcp8媅QvvM4 T~ug?teԇ?q´~/ULa-wB x@Y>`s.?+-h Y?# q~iDڀñ.cL6t6/ד5`aa_!Ķŀ,Ybg2N69f)cHۣ H DD-]xfdr9je]t49847/'&m{XXLgژwtD5G[ϗѬ`m k̻/dZS@Tphoum2Fѐ-ۂz5 duq߈Fh63#c &100zzz M3 v &Je V,3IۅB, @?hA=qAǰ0X v0$?1]9X.@eS >D@`c.2Yj Sx)6ܩ_PJxh<Xr*kz6Q`~}3J8,<_PV?jN(7A<8|]29p]nuS:zm=BijV6_jnC 4(* PoT\E8HlO;/̺{ZyD?RBf_;~O P`ELb<`BĢwҧm)ø4?Y뽣s0_(K@3lX.@ Vß|Z 2dA5w`.(< U5j%bBuN<;h`(C:@F5q,2=Dyߣ#hf0_O j,L hxu3qYxN&p;phr:IUkcs蹽oTxݛ ,:=O7Wȇ H !pLc EOsw^bd2n j ,t uiG=h6N %P6YsUo? 1͊3b@h ?B>3rЍP,Cѻ2v!x /@^CRR9hD;pCG g9mH8"fR9sE\飿ѐz-`!ĘäP.] 
uv]c3m @ Aq2c5osEn:9AOBЋQg1`UTtppȇC_jVT"B125Q8`ہ\ '^/)Aba:ݡ:;A5pf¡ZcgeQdʼn ̇öCDc3"$qz =U}p \`!޽gz@S;L :¡K9 V4YDUl>€$Ls`@0sv(f$~ͬzjIb/ k:[V,vmWRMA(8(8LfHCpVU Tr]/),:+8(8T܂pey9f a.8+:s+:ob(8(8TMZ2 p984ro%P΁  +!~CN+IM(0V8bhJٸ< >_U 2ک4Vph~8?j9mٟk1BTEPph 80\q!UWٔCAb@(8(8!Vp1>+($ >CEPpPp=C8x($ '@(8(8Ȟ!Vp PIiBAAtCP6 <@A&Qp(bᜀPpPpD!p8+ dO+84JBAd&CQ@(8(8Ȟ!Vph"8|   (84BAAtCa98h{t Q=]K=9L$ܥ{9/۽.{h.rqMqK'ËDCi¡kpPe/}Zt\7n^|OCKҟFA @I$ 40$sXkq=utFĵ|!x5}GSK/7xa̯z!qє\4YQ4=` t^9v={w Qh`˭J3M]shHaFǶlEAF4¬d{V3|[i+Iod'"} %G9eci:8`tˌX6DZjWp@MF,vcYCL*O4hbeCuz:$U|g@dp 1e{|O/>ťV2>Pk zZ gt]#XK [PHMB@EfPm hU5g:@)17/ :d>3t|QX }kVpBάuhG48Sp1XFG8U3Tr (&xg?!ORRKe޴)v}~+sjcKj-L# V֜ !tᄏ$3bĆ[8$q!+%Qxb~+շpoMOq#m{RS(?j^"ZTs@/ >Q ~ D¬_8/ &9׼N-DMgk~#TV#=EH0B6фf0\jƌl𤋮VQZ4;P(yF_w2e QMSV|kVHY͜IIٱ+z e.>xk̜V^O +"RKX2P(uOGr}q?hcG8"`ԅk^}& > XNC`AT,h8Qޔ8_әapʺ?[{8`KMAY xUeI ,Jk"#:pq"_m,J\񑵇^u{LH۬("4΢2M_o][n̛ٙz ms?p@_{5Q/Yd)iRpR`%:zѹ}Ew1{qQ-E08`@P 3VH;jrბȚumPXQ( 8E`d3$ 3V~נ 1L{Y֋ExhB}ބY!YJ 'ux_$Tz =-Wp8S[B-mBO±nû+j p m 0U2L;B8 $fuU9]ưQaq>,%)6rHpܾMN Hz)~BVci$7}:`"nq7YP~N/(8-/QgR7V,4̐'`7,f hډ܁Zg|}wn^01㓠Ih^M~ΰpXv9ox|+8,it`rx΁tJ?(uwN^W81cn87ްI'i/@ÿӵa%;Lb+AM\./4p3)ӱ)˾pëI!2 V$&+| 0ys:l'v,HH #S@`~)M†tʜp ̚[87bv@I@`ӅhtB0l~U| C*`PvH K)+j Dbɓِ3.=#-uB6#?m5Tp(RgD($'OMEmvYb6nMVqɲ m;~PD824jBs`iٿ14[. VGU $$Rhh/e HH>4;%nyM-{?ka+IENDB`eagerpy-0.30.0/docs/README.md000066400000000000000000000035041410374365400154460ustar00rootroot00000000000000--- home: true heroImage: /logo.png heroText: EagerPy tagline: Writing Code That Works Natively with PyTorch, TensorFlow, JAX, and NumPy actionText: Get Started → actionLink: /guide/ features: - title: Native Performance details: EagerPy operations get directly translated into the corresponding native operations. - title: Fully Chainable details: All functionality is available as methods on the tensor objects and as EagerPy functions. - title: Type Checking details: Catch bugs before running your code thanks to EagerPy's extensive type annotations. footer: Copyright © 2020 Jonas Rauber --- ### What is EagerPy? **EagerPy** is a **Python framework** that lets you write code that automatically works natively with [**PyTorch**](https://pytorch.org), [**TensorFlow**](https://www.tensorflow.org), [**JAX**](https://github.com/google/jax), and [**NumPy**](https://numpy.org). ```python import eagerpy as ep def norm(x): x = ep.astensor(x) result = x.square().sum().sqrt() return result.raw ``` You can now **use** the `norm` function **with native tensors** and arrays from PyTorch, TensorFlow, JAX and NumPy with **virtually no overhead compared to native code**. Of course, it also works with **GPU tensors**. ```python import torch norm(torch.tensor([1., 2., 3.])) # tensor(3.7417) ``` ```python import tensorflow as tf norm(tf.constant([1., 2., 3.])) # ``` ```python import jax.numpy as np norm(np.array([1., 2., 3.])) # DeviceArray(3.7416575, dtype=float32) ``` ```python import numpy as np norm(np.array([1., 2., 3.])) # 3.7416573867739413 ``` ### Getting Started You can install the latest release from [PyPI](https://pypi.org/project/eagerpy/) using `pip`: ```bash python3 -m pip install eagerpy ``` ::: warning NOTE EagerPy requires Python 3.6 or newer. 
::: eagerpy-0.30.0/docs/api/000077500000000000000000000000001410374365400147365ustar00rootroot00000000000000eagerpy-0.30.0/docs/api/README.md000066400000000000000000000000671410374365400162200ustar00rootroot00000000000000# Overview * `ep.*` functions * `ep.Tensor.*` methods eagerpy-0.30.0/docs/api/lib.md000066400000000000000000000001041410374365400160210ustar00rootroot00000000000000# eagerpy.* <<< @/../eagerpy/framework.py <<< @/../eagerpy/lib.py eagerpy-0.30.0/docs/api/norms.md000066400000000000000000000012261410374365400164170ustar00rootroot00000000000000# eagerpy.norms ## l0 ```python l0(x:~TensorType, axis:Union[int, Tuple[int, ...], NoneType]=None, keepdims:bool=False) -> ~TensorType ``` ## l1 ```python l1(x:~TensorType, axis:Union[int, Tuple[int, ...], NoneType]=None, keepdims:bool=False) -> ~TensorType ``` ## l2 ```python l2(x:~TensorType, axis:Union[int, Tuple[int, ...], NoneType]=None, keepdims:bool=False) -> ~TensorType ``` ## linf ```python linf(x:~TensorType, axis:Union[int, Tuple[int, ...], NoneType]=None, keepdims:bool=False) -> ~TensorType ``` ## lp ```python lp(x:~TensorType, p:Union[int, float], axis:Union[int, Tuple[int, ...], NoneType]=None, keepdims:bool=False) -> ~TensorType ``` eagerpy-0.30.0/docs/api/tensor.md000066400000000000000000000253101410374365400165730ustar00rootroot00000000000000# PyTorchTensor ```python PyTorchTensor(self, raw:'torch.Tensor') ``` # TensorFlowTensor ```python TensorFlowTensor(self, raw:'tf.Tensor') ``` # JAXTensor ```python JAXTensor(self, raw:'np.ndarray') ``` # NumPyTensor ```python NumPyTensor(self, raw:'np.ndarray') ``` # Tensor ```python Tensor(self, raw:Any) ``` Base class defining the common interface of all EagerPy Tensors ## sign ```python Tensor.sign(self:~TensorType) -> ~TensorType ``` ## sqrt ```python Tensor.sqrt(self:~TensorType) -> ~TensorType ``` ## tanh ```python Tensor.tanh(self:~TensorType) -> ~TensorType ``` ## float32 ```python Tensor.float32(self:~TensorType) -> ~TensorType ``` ## where ```python Tensor.where(self:~TensorType, x:Union[_ForwardRef('Tensor'), int, float], y:Union[_ForwardRef('Tensor'), int, float]) -> ~TensorType ``` ## matmul ```python Tensor.matmul(self:~TensorType, other:~TensorType) -> ~TensorType ``` ## ndim ## numpy ```python Tensor.numpy(self:~TensorType) -> Any ``` ## item ```python Tensor.item(self:~TensorType) -> Union[int, float] ``` ## shape ## reshape ```python Tensor.reshape(self:~TensorType, shape:Union[Tuple[int, ...], int]) -> ~TensorType ``` ## take_along_axis ```python Tensor.take_along_axis(self:~TensorType, index:~TensorType, axis:int) -> ~TensorType ``` ## astype ```python Tensor.astype(self:~TensorType, dtype:Any) -> ~TensorType ``` ## clip ```python Tensor.clip(self:~TensorType, min_:float, max_:float) -> ~TensorType ``` ## square ```python Tensor.square(self:~TensorType) -> ~TensorType ``` ## arctanh ```python Tensor.arctanh(self:~TensorType) -> ~TensorType ``` ## sum ```python Tensor.sum(self:~TensorType, axis:Union[int, Tuple[int, ...], NoneType]=None, keepdims:bool=False) -> ~TensorType ``` ## prod ```python Tensor.prod(self:~TensorType, axis:Union[int, Tuple[int, ...], NoneType]=None, keepdims:bool=False) -> ~TensorType ``` ## mean ```python Tensor.mean(self:~TensorType, axis:Union[int, Tuple[int, ...], NoneType]=None, keepdims:bool=False) -> ~TensorType ``` ## min ```python Tensor.min(self:~TensorType, axis:Union[int, Tuple[int, ...], NoneType]=None, keepdims:bool=False) -> ~TensorType ``` ## max ```python Tensor.max(self:~TensorType, axis:Union[int, Tuple[int, ...], 
NoneType]=None, keepdims:bool=False) -> ~TensorType ``` ## minimum ```python Tensor.minimum(self:~TensorType, other:Union[_ForwardRef('Tensor'), int, float]) -> ~TensorType ``` ## maximum ```python Tensor.maximum(self:~TensorType, other:Union[_ForwardRef('Tensor'), int, float]) -> ~TensorType ``` ## argmin ```python Tensor.argmin(self:~TensorType, axis:Union[int, NoneType]=None) -> ~TensorType ``` ## argmax ```python Tensor.argmax(self:~TensorType, axis:Union[int, NoneType]=None) -> ~TensorType ``` ## argsort ```python Tensor.argsort(self:~TensorType, axis:int=-1) -> ~TensorType ``` ## topk ```python Tensor.topk(self:~TensorType, k:int, sorted:bool=True) -> Tuple[~TensorType, ~TensorType] ``` ## uniform ```python Tensor.uniform(self:~TensorType, shape:Union[Tuple[int, ...], int], low:float=0.0, high:float=1.0) -> ~TensorType ``` ## normal ```python Tensor.normal(self:~TensorType, shape:Union[Tuple[int, ...], int], mean:float=0.0, stddev:float=1.0) -> ~TensorType ``` ## ones ```python Tensor.ones(self:~TensorType, shape:Union[Tuple[int, ...], int]) -> ~TensorType ``` ## zeros ```python Tensor.zeros(self:~TensorType, shape:Union[Tuple[int, ...], int]) -> ~TensorType ``` ## ones_like ```python Tensor.ones_like(self:~TensorType) -> ~TensorType ``` ## zeros_like ```python Tensor.zeros_like(self:~TensorType) -> ~TensorType ``` ## full_like ```python Tensor.full_like(self:~TensorType, fill_value:float) -> ~TensorType ``` ## onehot_like ```python Tensor.onehot_like(self:~TensorType, indices:~TensorType, *, value:float=1) -> ~TensorType ``` ## from_numpy ```python Tensor.from_numpy(self:~TensorType, a:Any) -> ~TensorType ``` ## transpose ```python Tensor.transpose(self:~TensorType, axes:Union[Tuple[int, ...], NoneType]=None) -> ~TensorType ``` ## bool ```python Tensor.bool(self:~TensorType) -> ~TensorType ``` ## all ```python Tensor.all(self:~TensorType, axis:Union[int, Tuple[int, ...], NoneType]=None, keepdims:bool=False) -> ~TensorType ``` ## any ```python Tensor.any(self:~TensorType, axis:Union[int, Tuple[int, ...], NoneType]=None, keepdims:bool=False) -> ~TensorType ``` ## logical_and ```python Tensor.logical_and(self:~TensorType, other:Union[_ForwardRef('Tensor'), int, float]) -> ~TensorType ``` ## logical_or ```python Tensor.logical_or(self:~TensorType, other:Union[_ForwardRef('Tensor'), int, float]) -> ~TensorType ``` ## logical_not ```python Tensor.logical_not(self:~TensorType) -> ~TensorType ``` ## exp ```python Tensor.exp(self:~TensorType) -> ~TensorType ``` ## log ```python Tensor.log(self:~TensorType) -> ~TensorType ``` ## log2 ```python Tensor.log2(self:~TensorType) -> ~TensorType ``` ## log10 ```python Tensor.log10(self:~TensorType) -> ~TensorType ``` ## log1p ```python Tensor.log1p(self:~TensorType) -> ~TensorType ``` ## tile ```python Tensor.tile(self:~TensorType, multiples:Tuple[int, ...]) -> ~TensorType ``` ## softmax ```python Tensor.softmax(self:~TensorType, axis:int=-1) -> ~TensorType ``` ## log_softmax ```python Tensor.log_softmax(self:~TensorType, axis:int=-1) -> ~TensorType ``` ## squeeze ```python Tensor.squeeze(self:~TensorType, axis:Union[int, Tuple[int, ...], NoneType]=None) -> ~TensorType ``` ## expand_dims ```python Tensor.expand_dims(self:~TensorType, axis:int) -> ~TensorType ``` ## full ```python Tensor.full(self:~TensorType, shape:Union[Tuple[int, ...], int], value:float) -> ~TensorType ``` ## index_update ```python Tensor.index_update(self:~TensorType, indices:Any, values:Union[_ForwardRef('Tensor'), int, float]) -> ~TensorType ``` ## arange ```python 
Tensor.arange(self:~TensorType, start:int, stop:Union[int, NoneType]=None, step:Union[int, NoneType]=None) -> ~TensorType ``` ## cumsum ```python Tensor.cumsum(self:~TensorType, axis:Union[int, NoneType]=None) -> ~TensorType ``` ## flip ```python Tensor.flip(self:~TensorType, axis:Union[int, Tuple[int, ...], NoneType]=None) -> ~TensorType ``` ## meshgrid ```python Tensor.meshgrid(self:~TensorType, *tensors:~TensorType, indexing:str='xy') -> Tuple[~TensorType, ...] ``` ## pad ```python Tensor.pad(self:~TensorType, paddings:Tuple[Tuple[int, int], ...], mode:str='constant', value:float=0) -> ~TensorType ``` ## isnan ```python Tensor.isnan(self:~TensorType) -> ~TensorType ``` ## isinf ```python Tensor.isinf(self:~TensorType) -> ~TensorType ``` ## crossentropy ```python Tensor.crossentropy(self:~TensorType, labels:~TensorType) -> ~TensorType ``` ## T ## abs ```python Tensor.abs(self:~TensorType) -> ~TensorType ``` ## pow ```python Tensor.pow(self:~TensorType, exponent:Union[_ForwardRef('Tensor'), int, float]) -> ~TensorType ``` ## value_and_grad ```python Tensor.value_and_grad(self:~TensorType, f:Callable[..., ~TensorType], *args:Any, **kwargs:Any) -> Tuple[~TensorType, ~TensorType] ``` ## value_aux_and_grad ```python Tensor.value_aux_and_grad(self:~TensorType, f:Callable[..., Tuple[~TensorType, Any]], *args:Any, **kwargs:Any) -> Tuple[~TensorType, Any, ~TensorType] ``` ## flatten ```python Tensor.flatten(self:~TensorType, start:int=0, end:int=-1) -> ~TensorType ``` ## l0 ```python NormsMethods.l0(x:~TensorType, axis:Union[int, Tuple[int, ...], NoneType]=None, keepdims:bool=False) -> ~TensorType ``` ## l1 ```python NormsMethods.l1(x:~TensorType, axis:Union[int, Tuple[int, ...], NoneType]=None, keepdims:bool=False) -> ~TensorType ``` ## l2 ```python NormsMethods.l2(x:~TensorType, axis:Union[int, Tuple[int, ...], NoneType]=None, keepdims:bool=False) -> ~TensorType ``` ## linf ```python NormsMethods.linf(x:~TensorType, axis:Union[int, Tuple[int, ...], NoneType]=None, keepdims:bool=False) -> ~TensorType ``` ## lp ```python NormsMethods.lp(x:~TensorType, p:Union[int, float], axis:Union[int, Tuple[int, ...], NoneType]=None, keepdims:bool=False) -> ~TensorType ``` ## raw ## dtype ## __init__ ```python Tensor.__init__(self, raw:Any) ``` ## __repr__ ```python Tensor.__repr__(self:~TensorType) -> str ``` ## __format__ ```python Tensor.__format__(self:~TensorType, format_spec:str) -> str ``` ## __getitem__ ```python Tensor.__getitem__(self:~TensorType, index:Any) -> ~TensorType ``` ## __iter__ ```python Tensor.__iter__(self:~TensorType) -> Iterator[~TensorType] ``` ## __bool__ ```python Tensor.__bool__(self:~TensorType) -> bool ``` ## __len__ ```python Tensor.__len__(self:~TensorType) -> int ``` ## __abs__ ```python Tensor.__abs__(self:~TensorType) -> ~TensorType ``` ## __neg__ ```python Tensor.__neg__(self:~TensorType) -> ~TensorType ``` ## __add__ ```python Tensor.__add__(self:~TensorType, other:Union[_ForwardRef('Tensor'), int, float]) -> ~TensorType ``` ## __radd__ ```python Tensor.__radd__(self:~TensorType, other:Union[_ForwardRef('Tensor'), int, float]) -> ~TensorType ``` ## __sub__ ```python Tensor.__sub__(self:~TensorType, other:Union[_ForwardRef('Tensor'), int, float]) -> ~TensorType ``` ## __rsub__ ```python Tensor.__rsub__(self:~TensorType, other:Union[_ForwardRef('Tensor'), int, float]) -> ~TensorType ``` ## __mul__ ```python Tensor.__mul__(self:~TensorType, other:Union[_ForwardRef('Tensor'), int, float]) -> ~TensorType ``` ## __rmul__ ```python Tensor.__rmul__(self:~TensorType, 
other:Union[_ForwardRef('Tensor'), int, float]) -> ~TensorType ``` ## __truediv__ ```python Tensor.__truediv__(self:~TensorType, other:Union[_ForwardRef('Tensor'), int, float]) -> ~TensorType ``` ## __rtruediv__ ```python Tensor.__rtruediv__(self:~TensorType, other:Union[_ForwardRef('Tensor'), int, float]) -> ~TensorType ``` ## __floordiv__ ```python Tensor.__floordiv__(self:~TensorType, other:Union[_ForwardRef('Tensor'), int, float]) -> ~TensorType ``` ## __rfloordiv__ ```python Tensor.__rfloordiv__(self:~TensorType, other:Union[_ForwardRef('Tensor'), int, float]) -> ~TensorType ``` ## __mod__ ```python Tensor.__mod__(self:~TensorType, other:Union[_ForwardRef('Tensor'), int, float]) -> ~TensorType ``` ## __lt__ ```python Tensor.__lt__(self:~TensorType, other:Union[_ForwardRef('Tensor'), int, float]) -> ~TensorType ``` ## __le__ ```python Tensor.__le__(self:~TensorType, other:Union[_ForwardRef('Tensor'), int, float]) -> ~TensorType ``` ## __eq__ ```python Tensor.__eq__(self:~TensorType, other:Union[_ForwardRef('Tensor'), int, float]) -> ~TensorType ``` ## __ne__ ```python Tensor.__ne__(self:~TensorType, other:Union[_ForwardRef('Tensor'), int, float]) -> ~TensorType ``` ## __gt__ ```python Tensor.__gt__(self:~TensorType, other:Union[_ForwardRef('Tensor'), int, float]) -> ~TensorType ``` ## __ge__ ```python Tensor.__ge__(self:~TensorType, other:Union[_ForwardRef('Tensor'), int, float]) -> ~TensorType ``` ## __pow__ ```python Tensor.__pow__(self:~TensorType, exponent:Union[_ForwardRef('Tensor'), int, float]) -> ~TensorType ``` eagerpy-0.30.0/docs/api/types.md000066400000000000000000000000531410374365400164220ustar00rootroot00000000000000# eagerpy.types <<< @/../eagerpy/types.py eagerpy-0.30.0/docs/guide/000077500000000000000000000000001410374365400152625ustar00rootroot00000000000000eagerpy-0.30.0/docs/guide/README.md000066400000000000000000000004531410374365400165430ustar00rootroot00000000000000# Introduction ## What is EagerPy? **EagerPy** is a **Python framework** that lets you write code that automatically works natively with [**PyTorch**](https://pytorch.org), [**TensorFlow**](https://www.tensorflow.org), [**JAX**](https://github.com/google/jax), and [**NumPy**](https://numpy.org). eagerpy-0.30.0/docs/guide/autodiff.md000066400000000000000000000020271410374365400174060ustar00rootroot00000000000000--- title: Automatic Differentiation --- # Automatic Differentiation in EagerPy EagerPy uses a functional approach to automatic differentiation. You first define a function that will then be differentiated with respect to its inputs. This function is then passed to `ep.value_and_grad` to evaluate both the function and its gradient. More generally, you can also use `ep.value_aux_and_grad` if your function has additional auxiliary outputs and `ep.value_and_grad_fn` if you want the gradient function without immediately evaluating it at some point `x`. Using `ep.value_and_grad` for automatic differentiation in EagerPy: ```python import torch x = torch.tensor([1., 2., 3.]) # The following code works for any framework, not just Pytorch! 
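# ep.astensor (below) wraps the native tensor in an EagerPy tensor, so the loss function and the gradient call go through the framework-agnostic EagerPy API rather than PyTorch-specific autograd calls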
import eagerpy as ep x = ep.astensor(x) def loss_fn(x): # this function takes and returns an EagerPy tensor return x.square().sum() print(loss_fn(x)) # PyTorchTensor(tensor(14.)) print(ep.value_and_grad(loss_fn, x)) # (PyTorchTensor(tensor(14.)), PyTorchTensor(tensor([2., 4., 6.]))) ```
eagerpy-0.30.0/docs/guide/citation.md000066400000000000000000000006651410374365400174250ustar00rootroot00000000000000# Citation If you use EagerPy, please cite our [paper](https://arxiv.org/abs/2008.04175) using this BibTeX entry: ```bibtex @article{rauber2020eagerpy, title={{EagerPy}: Writing Code That Works Natively with {PyTorch}, {TensorFlow}, {JAX}, and {NumPy}}, author={Rauber, Jonas and Bethge, Matthias and Brendel, Wieland}, journal={arXiv preprint arXiv:2008.04175}, year={2020}, url={https://eagerpy.jonasrauber.de}, } ```
eagerpy-0.30.0/docs/guide/converting.md000066400000000000000000000033221410374365400177620ustar00rootroot00000000000000--- title: Converting --- # Converting Between EagerPy and Native Tensors A native tensor could be a PyTorch GPU or CPU tensor, a TensorFlow tensor, a JAX array, or a NumPy array. **A native PyTorch tensor:** ```python import torch x = torch.tensor([1., 2., 3., 4., 5., 6.]) ``` **A native TensorFlow tensor:** ```python import tensorflow as tf x = tf.constant([1., 2., 3., 4., 5., 6.]) ``` **A native JAX array:** ```python import jax.numpy as np x = np.array([1., 2., 3., 4., 5., 6.]) ``` **A native NumPy array:** ```python import numpy as np x = np.array([1., 2., 3., 4., 5., 6.]) ``` No matter which native tensor you have, it can always be turned into the appropriate EagerPy tensor using `ep.astensor`. This will automatically wrap the native tensor with the correct EagerPy tensor class. The original native tensor can always be accessed using the `.raw` attribute. ```python # x should be a native tensor (see above) # for example: import torch x = torch.tensor([1., 2., 3., 4., 5., 6.]) # Any native tensor can easily be turned into an EagerPy tensor import eagerpy as ep x = ep.astensor(x) # Now we can perform any EagerPy operation x = x.square() # And convert the EagerPy tensor back into a native tensor x = x.raw # x will now again be a native tensor (e.g. a PyTorch tensor) ``` Especially in functions, it is common to convert all inputs to EagerPy tensors. This could be done using individual calls to `ep.astensor`, but using `ep.astensors` this can be written even more compactly. ```python # x, y should be native tensors (see above) # for example: import torch x = torch.tensor([1., 2., 3.]) y = torch.tensor([4., 5., 6.]) import eagerpy as ep x, y = ep.astensors(x, y) # works for any number of inputs ```
eagerpy-0.30.0/docs/guide/development.md000066400000000000000000000050121410374365400201240ustar00rootroot00000000000000# Development ::: tip NOTE The following is only necessary if you want to contribute features to EagerPy. As a user of EagerPy, you can just do a normal [installation](./getting-started). ::: ## Installation First clone the repository using `git`: ```bash git clone https://github.com/jonasrauber/eagerpy ``` You can then do an editable installation using `pip install -e`: ```bash cd eagerpy pip3 install -e . ``` ::: tip Create a new branch for each new feature or contribution. This will be necessary to open a pull request later. ::: ## Coding Style We follow the [PEP 8 Style Guide for Python Code](https://www.python.org/dev/peps/pep-0008/). We use [black](https://github.com/psf/black) for automatic code formatting.
In addition, we use [flake8](https://flake8.pycqa.org/en/latest/) to detect certain PEP 8 violations. ::: tip Have a look at the `Makefile`. It contains many useful commands, e.g. `make black` or `make flake8`. ::: ## Type annotations and MyPy EagerPy uses Python type annotations introduced in [PEP 484](https://www.python.org/dev/peps/pep-0484/). We use [mypy](http://mypy-lang.org) for static type checking with relatively strict settings. All code in EagerPy has to be type annotated. We recommend running MyPy or a comparable type checker automatically in your editor (e.g. VIM) or IDE (e.g. PyCharm). You can also run MyPy from the command line: ```bash make mypy # run this in the root folder that contains the Makefile ``` ::: tip NOTE `__init__` methods should not have return type annotations unless they have no type annotated arguments (i.e. only `self`), in which case the return type of `__init__` should be specified as `None`. ::: ## Creating a pull request on GitHub First, fork the [EagerPy repository on GitHub](https://github.com/jonasrauber/eagerpy). Then, add the fork to your local GitHub repository: ```bash git remote add fork https://github.com/YOUR USERNAME/eagerpy ``` Finally, push your new branch to GitHub and open a pull request. ## Release Process EagerPy currently follows a rapid release process. Whenever non-trivial changes have been made, the documentation and tests have been updated, and all tests pass, a new version can be released. To reduce the barrier, this is simply done by creating a new release on GitHub. This automatically triggers a [GitHub Action](https://github.com/jonasrauber/eagerpy/actions) that builds the package and publishes the new version on the Python Package Index [PyPI](https://pypi.org/project/eagerpy/). The latest version can thus be simply installed using `pip`:
eagerpy-0.30.0/docs/guide/examples.md000066400000000000000000000017111410374365400174220ustar00rootroot00000000000000--- title: Examples --- # Examples :tada: ## A framework-agnostic `norm` function Write your function using EagerPy: ```python import eagerpy as ep def norm(x): x = ep.astensor(x) result = x.square().sum().sqrt() return result.raw ``` You can now **use** the `norm` function **with native tensors** and arrays from PyTorch, TensorFlow, JAX and NumPy with **virtually no overhead compared to native code**. Of course, it also works with **GPU tensors**. ```python import torch norm(torch.tensor([1., 2., 3.])) # tensor(3.7417) ``` ```python import tensorflow as tf norm(tf.constant([1., 2., 3.])) # ``` ```python import jax.numpy as np norm(np.array([1., 2., 3.])) # DeviceArray(3.7416575, dtype=float32) ``` ```python import numpy as np norm(np.array([1., 2., 3.])) # 3.7416573867739413 ``` ::: tip NOTE EagerPy already comes with a [builtin implementation of `norm`](/api/norms.md#l2). :::
eagerpy-0.30.0/docs/guide/generic-functions.md000066400000000000000000000044101410374365400212250ustar00rootroot00000000000000--- title: Generic Functions --- # Implementing Generic Framework-Agnostic Functions Using the conversion functions shown in [Converting](./converting.md), we can already define a simple framework-agnostic function. ```python import eagerpy as ep def norm(x): x = ep.astensor(x) result = x.square().sum().sqrt() return result.raw ``` This function can be called with a native tensor from any framework, and it will return the norm of that tensor, again as a native tensor from that framework.
Calling the `norm` function using a PyTorch tensor: ```python import torch norm(torch.tensor([1., 2., 3.])) # tensor(3.7417) ``` Calling the `norm` function using a TensorFlow tensor: ```python import tensorflow as tf norm(tf.constant([1., 2., 3.])) # ``` If we were to call the above `norm` function with an EagerPy tensor, the `ep.astensor` call would simply return its input. The `result.raw` call in the last line would, however, still extract the underlying native tensor. Often it is preferable to implement a generic function that not only transparently handles any native tensor but also EagerPy tensors, that is, the return type should always match the input type. This is particularly useful in libraries like Foolbox that allow users to work with EagerPy and native tensors. To achieve that, EagerPy comes with two derivatives of the above conversion functions: `ep.astensor_` and `ep.astensors_`. Unlike their counterparts without an underscore, they return an additional inversion function that restores the input type. If the input to `astensor_` is a native tensor, `restore_type` will be identical to `.raw`, but if the original input was an EagerPy tensor, `restore_type` will not call `.raw`. With that, we can write generic framework-agnostic functions that work transparently for any input. An improved framework-agnostic `norm` function: ```python import eagerpy as ep def norm(x): x, restore_type = ep.astensor_(x) result = x.square().sum().sqrt() return restore_type(result) ``` Converting and restoring multiple inputs using `ep.astensors_`: ```python import eagerpy as ep def example(x, y, z): (x, y, z), restore_type = ep.astensors_(x, y, z) result = (x + y) * z return restore_type(result) ```
eagerpy-0.30.0/docs/guide/getting-started.md000066400000000000000000000013111410374365400207060ustar00rootroot00000000000000# Getting Started ## Installation You can install the latest release from [PyPI](https://pypi.org/project/eagerpy/) using `pip`: ```bash python3 -m pip install eagerpy ``` EagerPy requires Python 3.6 or newer. Besides that, all essential dependencies are automatically installed. To use it with PyTorch, TensorFlow, JAX, or NumPy, the respective framework needs to be installed separately. These frameworks are not declared as dependencies because not everyone wants to use and thus install all of them and because some of these packages have different builds for different architectures and [CUDA](https://developer.nvidia.com/cuda-zone) versions. ::: warning NOTE EagerPy requires Python 3.6 or newer. :::
eagerpy-0.30.0/eagerpy/000077500000000000000000000000001410374365400146715ustar00rootroot00000000000000eagerpy-0.30.0/eagerpy/VERSION000066400000000000000000000000071410374365400157360ustar00rootroot000000000000000.30.0
eagerpy-0.30.0/eagerpy/__init__.py000066400000000000000000000024271410374365400170070ustar00rootroot00000000000000from typing import TypeVar from os.path import join as _join from os.path import dirname as _dirname with open(_join(_dirname(__file__), "VERSION")) as _f: __version__ = _f.read().strip() _T = TypeVar("_T") class _Indexable: __slots__ = () def __getitem__(self, index: _T) -> _T: return index index = _Indexable() from .tensor import Tensor # noqa: F401,E402 from .tensor import TensorType # noqa: F401,E402 from .tensor import istensor # noqa: F401,E402 from .tensor import PyTorchTensor # noqa: F401,E402 from .tensor import TensorFlowTensor # noqa: F401,E402 from .tensor import NumPyTensor # noqa: F401,E402 from .tensor import JAXTensor # noqa: F401,E402 from .
import types # noqa: F401,E402 from .astensor import astensor # noqa: F401,E402 from .astensor import astensors # noqa: F401,E402 from .astensor import astensor_ # noqa: F401,E402 from .astensor import astensors_ # noqa: F401,E402 from .modules import torch # noqa: F401,E402 from .modules import tensorflow # noqa: F401,E402 from .modules import jax # noqa: F401,E402 from .modules import numpy # noqa: F401,E402 from . import utils # noqa: F401,E402 from .framework import * # noqa: F401,E402,F403 from . import norms # noqa: F401,E402 from .lib import * # noqa: F401,E402,F403 eagerpy-0.30.0/eagerpy/astensor.py000066400000000000000000000051661410374365400171110ustar00rootroot00000000000000from typing import TYPE_CHECKING, Union, overload, Tuple, TypeVar, Generic, Any import sys from .tensor import Tensor from .tensor import TensorType from .tensor import PyTorchTensor from .tensor import TensorFlowTensor from .tensor import JAXTensor from .tensor import NumPyTensor from .types import NativeTensor if TYPE_CHECKING: # for static analyzers import torch def _get_module_name(x: Any) -> str: # splitting is necessary for TensorFlow tensors return type(x).__module__.split(".")[0] @overload def astensor(x: TensorType) -> TensorType: ... @overload def astensor(x: "torch.Tensor") -> PyTorchTensor: ... @overload def astensor(x: NativeTensor) -> Tensor: # type: ignore ... def astensor(x: Union[NativeTensor, Tensor]) -> Tensor: # type: ignore if isinstance(x, Tensor): return x # we use the module name instead of isinstance # to avoid importing all the frameworks name = _get_module_name(x) m = sys.modules if name == "torch" and isinstance(x, m[name].Tensor): # type: ignore return PyTorchTensor(x) if name == "tensorflow" and isinstance(x, m[name].Tensor): # type: ignore return TensorFlowTensor(x) if (name == "jax" or name == "jaxlib") and isinstance(x, m["jax"].numpy.ndarray): # type: ignore return JAXTensor(x) if name == "numpy" and isinstance(x, m[name].ndarray): # type: ignore return NumPyTensor(x) raise ValueError(f"Unknown type: {type(x)}") def astensors(*xs: Union[NativeTensor, Tensor]) -> Tuple[Tensor, ...]: # type: ignore return tuple(astensor(x) for x in xs) T = TypeVar("T") class RestoreTypeFunc(Generic[T]): def __init__(self, x: T): self.unwrap = not isinstance(x, Tensor) @overload def __call__(self, x: Tensor) -> T: ... @overload # noqa: F811 def __call__(self, x: Tensor, y: Tensor) -> Tuple[T, T]: ... @overload # noqa: F811 def __call__(self, x: Tensor, y: Tensor, z: Tensor, *args: Tensor) -> Tuple[T, ...]: ... @overload # noqa: F811 def __call__(self, *args: Any) -> Any: # catch other types, otherwise we would return type T for input type Any ... 
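    # actual implementation behind the overloads above: unwrap EagerPy tensors back to the raw native tensors, but only if the original input was a native (non-EagerPy) tensor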
def __call__(self, *args): # type: ignore # noqa: F811 result = tuple(x.raw for x in args) if self.unwrap else args if len(result) == 1: (result,) = result return result def astensor_(x: T) -> Tuple[Tensor, RestoreTypeFunc[T]]: return astensor(x), RestoreTypeFunc[T](x) def astensors_(x: T, *xs: T) -> Tuple[Tuple[Tensor, ...], RestoreTypeFunc[T]]: return astensors(x, *xs), RestoreTypeFunc[T](x) eagerpy-0.30.0/eagerpy/framework.py000066400000000000000000000211741410374365400172450ustar00rootroot00000000000000from typing import overload, Sequence, Callable, Tuple, Any, Optional, cast, Union from typing_extensions import Literal from .types import Axes, AxisAxes, Shape, ShapeOrScalar from .tensor import Tensor from .tensor import TensorType from .tensor import TensorOrScalar newaxis = None inf = float("inf") nan = float("nan") def clip(t: TensorType, min_: float, max_: float) -> TensorType: return t.clip(min_, max_) def abs(t: TensorType) -> TensorType: return t.abs() def sign(t: TensorType) -> TensorType: return t.sign() def sqrt(t: TensorType) -> TensorType: return t.sqrt() def square(t: TensorType) -> TensorType: return t.square() def pow(t: TensorType, exponent: TensorOrScalar) -> TensorType: return t.pow(exponent) def tanh(t: TensorType) -> TensorType: return t.tanh() def arctanh(t: TensorType) -> TensorType: return t.arctanh() def sum( t: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: return t.sum(axis=axis, keepdims=keepdims) def prod( t: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: return t.prod(axis=axis, keepdims=keepdims) def mean( t: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: return t.mean(axis=axis, keepdims=keepdims) def min( t: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: return t.min(axis=axis, keepdims=keepdims) def max( t: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: return t.max(axis=axis, keepdims=keepdims) @overload def minimum(x: TensorType, y: TensorOrScalar) -> TensorType: ... @overload def minimum(x: TensorOrScalar, y: TensorType) -> TensorType: ... def minimum(x: TensorOrScalar, y: TensorOrScalar) -> Tensor: if not isinstance(x, Tensor): return cast(Tensor, y).minimum(x) return x.minimum(y) @overload def maximum(x: TensorType, y: TensorOrScalar) -> TensorType: ... @overload def maximum(x: TensorOrScalar, y: TensorType) -> TensorType: ... 
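# mirrors minimum above: if x is a plain scalar, dispatch to y.maximum(x) so that scalars are accepted on either side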
def maximum(x: TensorOrScalar, y: TensorOrScalar) -> Tensor: if not isinstance(x, Tensor): return cast(Tensor, y).maximum(x) return x.maximum(y) def argmin(t: TensorType, axis: Optional[int] = None) -> TensorType: return t.argmin(axis=axis) def argmax(t: TensorType, axis: Optional[int] = None) -> TensorType: return t.argmax(axis=axis) def argsort(t: TensorType, axis: int = -1) -> TensorType: return t.argsort(axis=axis) def sort(t: TensorType, axis: int = -1) -> TensorType: return t.sort(axis=axis) def topk(t: TensorType, k: int, sorted: bool = True) -> Tuple[TensorType, TensorType]: return t.topk(k, sorted=sorted) def uniform( t: TensorType, shape: ShapeOrScalar, low: float = 0.0, high: float = 1.0 ) -> TensorType: return t.uniform(shape, low=low, high=high) def normal( t: TensorType, shape: ShapeOrScalar, mean: float = 0.0, stddev: float = 1.0 ) -> TensorType: return t.normal(shape, mean=mean, stddev=stddev) def ones(t: TensorType, shape: ShapeOrScalar) -> TensorType: return t.ones(shape) def zeros(t: TensorType, shape: ShapeOrScalar) -> TensorType: return t.zeros(shape) def ones_like(t: TensorType) -> TensorType: return t.ones_like() def zeros_like(t: TensorType) -> TensorType: return t.zeros_like() def full_like(t: TensorType, fill_value: float) -> TensorType: return t.full_like(fill_value) def onehot_like(t: TensorType, indices: TensorType, *, value: float = 1) -> TensorType: return t.onehot_like(indices, value=value) def from_numpy(t: TensorType, a: Any) -> TensorType: return t.from_numpy(a) def concatenate(tensors: Sequence[TensorType], axis: int = 0) -> TensorType: t = tensors[0] return t._concatenate(tensors, axis=axis) def transpose(t: TensorType, axes: Optional[Axes] = None) -> TensorType: return t.transpose(axes=axes) @overload def logical_and(x: TensorType, y: TensorOrScalar) -> TensorType: ... @overload def logical_and(x: TensorOrScalar, y: TensorType) -> TensorType: ... def logical_and(x: TensorOrScalar, y: TensorOrScalar) -> Tensor: if not isinstance(x, Tensor): return cast(Tensor, y).logical_and(x) return x.logical_and(y) @overload def logical_or(x: TensorType, y: TensorOrScalar) -> TensorType: ... @overload def logical_or(x: TensorOrScalar, y: TensorType) -> TensorType: ... 
def logical_or(x: TensorOrScalar, y: TensorOrScalar) -> Tensor: if not isinstance(x, Tensor): return cast(Tensor, y).logical_or(x) return x.logical_or(y) def logical_not(t: TensorType) -> TensorType: return t.logical_not() def exp(t: TensorType) -> TensorType: return t.exp() def log(t: TensorType) -> TensorType: return t.log() def log2(t: TensorType) -> TensorType: return t.log2() def log10(t: TensorType) -> TensorType: return t.log10() def log1p(t: TensorType) -> TensorType: return t.log1p() def where(condition: TensorType, x: TensorOrScalar, y: TensorOrScalar) -> TensorType: return condition.where(x, y) def tile(t: TensorType, multiples: Axes) -> TensorType: return t.tile(multiples) def matmul(x: TensorType, y: TensorType) -> TensorType: return x.matmul(y) def softmax(t: TensorType, axis: int = -1) -> TensorType: return t.softmax(axis=axis) def log_softmax(t: TensorType, axis: int = -1) -> TensorType: return t.log_softmax(axis=axis) def stack(tensors: Sequence[TensorType], axis: int = 0) -> TensorType: t = tensors[0] return t._stack(tensors, axis=axis) def squeeze(t: TensorType, axis: Optional[AxisAxes] = None) -> TensorType: return t.squeeze(axis=axis) def expand_dims(t: TensorType, axis: int) -> TensorType: return t.expand_dims(axis=axis) def full(t: TensorType, shape: ShapeOrScalar, value: float) -> TensorType: return t.full(shape, value) def index_update(t: TensorType, indices: Any, values: TensorOrScalar) -> TensorType: return t.index_update(indices, values) def arange( t: TensorType, start: int, stop: Optional[int] = None, step: Optional[int] = None ) -> TensorType: return t.arange(start, stop, step) def cumsum(t: TensorType, axis: Optional[int] = None) -> TensorType: return t.cumsum(axis=axis) def flip(t: TensorType, axis: Optional[AxisAxes] = None) -> TensorType: return t.flip(axis=axis) def meshgrid( t: TensorType, *tensors: TensorType, indexing: str = "xy" ) -> Tuple[TensorType, ...]: return t.meshgrid(*tensors, indexing=indexing) def pad( t: TensorType, paddings: Tuple[Tuple[int, int], ...], mode: str = "constant", value: float = 0, ) -> TensorType: return t.pad(paddings, mode=mode, value=value) def isnan(t: TensorType) -> TensorType: return t.isnan() def isinf(t: TensorType) -> TensorType: return t.isinf() def all( t: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: return t.all(axis=axis, keepdims=keepdims) def any( t: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: return t.any(axis=axis, keepdims=keepdims) def crossentropy(logits: TensorType, labels: TensorType) -> TensorType: return logits.crossentropy(labels) def slogdet(matrix: TensorType) -> Tuple[TensorType, TensorType]: return matrix.slogdet() @overload def value_and_grad_fn( t: TensorType, f: Callable[..., TensorType] ) -> Callable[..., Tuple[TensorType, TensorType]]: ... @overload def value_and_grad_fn( t: TensorType, f: Callable[..., TensorType], has_aux: Literal[False] ) -> Callable[..., Tuple[TensorType, TensorType]]: ... @overload def value_and_grad_fn( t: TensorType, f: Callable[..., Tuple[TensorType, Any]], has_aux: Literal[True] ) -> Callable[..., Tuple[TensorType, Any, TensorType]]: ... 
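# the overloads above only refine the return type: with has_aux=True the returned gradient function yields (value, aux, grad) instead of (value, grad); the implementation simply forwards to the tensor-specific _value_and_grad_fn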
def value_and_grad_fn(t: Any, f: Any, has_aux: bool = False) -> Any: return t._value_and_grad_fn(f, has_aux=has_aux) def value_and_grad( f: Callable[..., TensorType], t: TensorType, *args: Any, **kwargs: Any ) -> Tuple[TensorType, TensorType]: return t.value_and_grad(f, *args, **kwargs) def value_aux_and_grad( f: Callable[..., Tuple[TensorType, Any]], t: TensorType, *args: Any, **kwargs: Any ) -> Tuple[TensorType, Any, TensorType]: return t.value_aux_and_grad(f, *args, **kwargs) def reshape(t: TensorType, shape: Union[Shape, int]) -> TensorType: return t.reshape(shape) def take_along_axis(t: TensorType, indices: TensorType, axis: int) -> TensorType: return t.take_along_axis(indices, axis) def flatten(t: TensorType, start: int = 0, end: int = -1) -> TensorType: return t.flatten(start=start, end=end) eagerpy-0.30.0/eagerpy/lib.py000066400000000000000000000005361410374365400160150ustar00rootroot00000000000000from .tensor import TensorType def kl_div_with_logits( logits_p: TensorType, logits_q: TensorType, axis: int = -1, keepdims: bool = False ) -> TensorType: log_p = logits_p.log_softmax(axis=axis) log_q = logits_q.log_softmax(axis=axis) p = logits_p.softmax(axis=-1) return (p * (log_p - log_q)).sum(axis=axis, keepdims=keepdims) eagerpy-0.30.0/eagerpy/modules.py000066400000000000000000000025241410374365400167160ustar00rootroot00000000000000from importlib import import_module import inspect from types import ModuleType from typing import Any, Callable, Iterable import functools from .astensor import astensor def wrap(f: Callable) -> Callable: @functools.wraps(f) def wrapper(*args: Any, **kwargs: Any) -> Any: result = f(*args, **kwargs) try: result = astensor(result) except ValueError: pass return result return wrapper class ModuleWrapper(ModuleType): """A wrapper for modules that delays the import until it is needed and wraps the output of functions as EagerPy tensors""" def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) if self.__doc__ is None: self.__doc__ = f"EagerPy wrapper of the '{self.__name__}' module" def __dir__(self) -> Iterable[str]: # makes sure tab completion works return import_module(self.__name__).__dir__() def __getattr__(self, name: str) -> Any: attr = getattr(import_module(self.__name__), name) if callable(attr): attr = wrap(attr) elif inspect.ismodule(attr): attr = ModuleWrapper(attr.__name__) return attr torch = ModuleWrapper("torch") tensorflow = ModuleWrapper("tensorflow") jax = ModuleWrapper("jax") numpy = ModuleWrapper("numpy") eagerpy-0.30.0/eagerpy/norms.py000066400000000000000000000023411410374365400164010ustar00rootroot00000000000000from typing import Union, Optional from .tensor import TensorType from .types import AxisAxes from .framework import inf def l0( x: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: return (x != 0).sum(axis=axis, keepdims=keepdims).astype(x.dtype) def l1( x: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: return x.abs().sum(axis=axis, keepdims=keepdims) def l2( x: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: return x.square().sum(axis=axis, keepdims=keepdims).sqrt() def linf( x: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: return x.abs().max(axis=axis, keepdims=keepdims) def lp( x: TensorType, p: Union[int, float], axis: Optional[AxisAxes] = None, keepdims: bool = False, ) -> TensorType: if p == 0: return l0(x, axis=axis, keepdims=keepdims) if p == 1: 
return l1(x, axis=axis, keepdims=keepdims) if p == 2: return l2(x, axis=axis, keepdims=keepdims) if p == inf: return linf(x, axis=axis, keepdims=keepdims) return x.abs().pow(p).sum(axis=axis, keepdims=keepdims).pow(1.0 / p) eagerpy-0.30.0/eagerpy/py.typed000066400000000000000000000000001410374365400163560ustar00rootroot00000000000000eagerpy-0.30.0/eagerpy/tensor/000077500000000000000000000000001410374365400162035ustar00rootroot00000000000000eagerpy-0.30.0/eagerpy/tensor/__init__.py000066400000000000000000000005611410374365400203160ustar00rootroot00000000000000from .tensor import Tensor # noqa: F401 from .tensor import TensorType # noqa: F401 from .tensor import TensorOrScalar # noqa: F401 from .tensor import istensor # noqa: F401 from .pytorch import PyTorchTensor # noqa: F401 from .tensorflow import TensorFlowTensor # noqa: F401 from .numpy import NumPyTensor # noqa: F401 from .jax import JAXTensor # noqa: F401 eagerpy-0.30.0/eagerpy/tensor/base.py000066400000000000000000000065321410374365400174750ustar00rootroot00000000000000from typing_extensions import final from typing import Any, cast from .tensor import Tensor from .tensor import TensorType from .tensor import TensorOrScalar def unwrap_(*args: Any) -> Any: return tuple(t.raw if isinstance(t, Tensor) else t for t in args) def unwrap1(t: Any) -> Any: return t.raw if isinstance(t, Tensor) else t class BaseTensor(Tensor): __slots__ = "_raw" def __init__(self: TensorType, raw: Any): assert not isinstance(raw, Tensor) self._raw = raw @property def raw(self) -> Any: return self._raw @final def __repr__(self: TensorType) -> str: lines = repr(self.raw).split("\n") prefix = self.__class__.__name__ + "(" lines[0] = prefix + lines[0] prefix = " " * len(prefix) for i in range(1, len(lines)): lines[i] = prefix + lines[i] lines[-1] = lines[-1] + ")" return "\n".join(lines) @final def __format__(self: TensorType, format_spec: str) -> str: return format(self.raw, format_spec) @final @property def dtype(self: TensorType) -> Any: return self.raw.dtype @final def __bool__(self: TensorType) -> bool: return bool(self.raw) @final def __len__(self: TensorType) -> int: return cast(int, self.raw.shape[0]) @final def __abs__(self: TensorType) -> TensorType: return type(self)(abs(self.raw)) @final def __neg__(self: TensorType) -> TensorType: return type(self)(-self.raw) @final def __add__(self: TensorType, other: TensorOrScalar) -> TensorType: return type(self)(self.raw.__add__(unwrap1(other))) @final def __radd__(self: TensorType, other: TensorOrScalar) -> TensorType: return type(self)(self.raw.__radd__(unwrap1(other))) @final def __sub__(self: TensorType, other: TensorOrScalar) -> TensorType: return type(self)(self.raw.__sub__(unwrap1(other))) @final def __rsub__(self: TensorType, other: TensorOrScalar) -> TensorType: return type(self)(self.raw.__rsub__(unwrap1(other))) @final def __mul__(self: TensorType, other: TensorOrScalar) -> TensorType: return type(self)(self.raw.__mul__(unwrap1(other))) @final def __rmul__(self: TensorType, other: TensorOrScalar) -> TensorType: return type(self)(self.raw.__rmul__(unwrap1(other))) @final def __truediv__(self: TensorType, other: TensorOrScalar) -> TensorType: return type(self)(self.raw.__truediv__(unwrap1(other))) @final def __rtruediv__(self: TensorType, other: TensorOrScalar) -> TensorType: return type(self)(self.raw.__rtruediv__(unwrap1(other))) @final def __floordiv__(self: TensorType, other: TensorOrScalar) -> TensorType: return type(self)(self.raw.__floordiv__(unwrap1(other))) @final def __rfloordiv__(self: 
TensorType, other: TensorOrScalar) -> TensorType: return type(self)(self.raw.__rfloordiv__(unwrap1(other))) @final def __mod__(self: TensorType, other: TensorOrScalar) -> TensorType: return type(self)(self.raw.__mod__(unwrap1(other))) @final def __pow__(self: TensorType, exponent: TensorOrScalar) -> TensorType: return type(self)(self.raw.__pow__(unwrap1(exponent))) @final @property def ndim(self: TensorType) -> int: return len(self.raw.shape) eagerpy-0.30.0/eagerpy/tensor/extensions.py000066400000000000000000000030341410374365400207540ustar00rootroot00000000000000from typing import TypeVar, Callable, Any, Generic import typing import functools from .. import norms from .tensor import Tensor T = TypeVar("T") def extensionmethod(f: Callable[..., T]) -> Callable[..., T]: @functools.wraps(f) def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: return f(self._instance, *args, **kwargs) return wrapper class ExtensionMeta(type): def __new__(cls, name, bases, attrs): # type: ignore if bases != (): # creating a subclass of ExtensionMethods # wrap the attributes with extensionmethod attrs = { k: extensionmethod(v) if not k.startswith("__") else v for k, v in attrs.items() } return super().__new__(cls, name, bases, attrs) if hasattr(typing, "GenericMeta"): # Python 3.6 # workaround for https://github.com/python/typing/issues/449 class GenericExtensionMeta(typing.GenericMeta, ExtensionMeta): pass else: # pragma: no cover # Python 3.7 and newer class GenericExtensionMeta(ExtensionMeta): # type: ignore pass class ExtensionMethods(metaclass=GenericExtensionMeta): def __init__(self, instance: Tensor): self._instance = instance T_co = TypeVar("T_co", bound=Tensor, covariant=True) class NormsMethods(Generic[T_co], ExtensionMethods): l0: Callable[..., T_co] = norms.l0 l1: Callable[..., T_co] = norms.l1 l2: Callable[..., T_co] = norms.l2 linf: Callable[..., T_co] = norms.linf lp: Callable[..., T_co] = norms.lp eagerpy-0.30.0/eagerpy/tensor/jax.py000066400000000000000000000454451410374365400173530ustar00rootroot00000000000000from typing import ( Tuple, cast, Union, Any, TypeVar, TYPE_CHECKING, Iterable, Optional, overload, Callable, Type, ) from typing_extensions import Literal from importlib import import_module import numpy as onp from ..types import Axes, AxisAxes, Shape, ShapeOrScalar from .tensor import Tensor from .tensor import TensorOrScalar from .base import BaseTensor from .base import unwrap_ from .base import unwrap1 if TYPE_CHECKING: # for static analyzers import jax import jax.numpy as np from .extensions import NormsMethods # noqa: F401 else: # lazy import in JAXTensor jax = None np = None # stricter TensorType to support additional internal methods TensorType = TypeVar("TensorType", bound="JAXTensor") def assert_bool(x: Any) -> None: if not isinstance(x, Tensor): return if x.dtype != jax.numpy.bool_: raise ValueError(f"requires dtype bool, got {x.dtype}, consider t.bool().all()") def getitem_preprocess(x: Any) -> Any: if isinstance(x, range): return list(x) elif isinstance(x, Tensor): return x.raw else: return x class JAXTensor(BaseTensor): __slots__ = () # more specific types for the extensions norms: "NormsMethods[JAXTensor]" _registered = False key = None def __new__(cls: Type["JAXTensor"], *args: Any, **kwargs: Any) -> "JAXTensor": if not cls._registered: import jax def flatten(t: JAXTensor) -> Tuple[Any, None]: return ((t.raw,), None) def unflatten(aux_data: None, children: Tuple) -> JAXTensor: return cls(*children) jax.tree_util.register_pytree_node(cls, flatten, unflatten) 
cls._registered = True return cast(JAXTensor, super().__new__(cls)) def __init__(self, raw: "np.ndarray"): # type: ignore global jax global np if jax is None: jax = import_module("jax") np = import_module("jax.numpy") super().__init__(raw) @property def raw(self) -> "np.ndarray": # type: ignore return super().raw @classmethod def _get_subkey(cls) -> Any: if cls.key is None: cls.key = jax.random.PRNGKey(0) cls.key, subkey = jax.random.split(cls.key) return subkey def numpy(self) -> Any: a = onp.asarray(self.raw) assert a.flags.writeable is False return a def item(self) -> Union[int, float, bool]: return self.raw.item() # type: ignore @property def shape(self) -> Shape: return cast(Tuple, self.raw.shape) def reshape(self: TensorType, shape: Union[Shape, int]) -> TensorType: if isinstance(shape, int): shape = (shape,) return type(self)(self.raw.reshape(shape)) def astype(self: TensorType, dtype: Any) -> TensorType: return type(self)(self.raw.astype(dtype)) def clip(self: TensorType, min_: float, max_: float) -> TensorType: return type(self)(np.clip(self.raw, min_, max_)) def square(self: TensorType) -> TensorType: return type(self)(np.square(self.raw)) def arctanh(self: TensorType) -> TensorType: return type(self)(np.arctanh(self.raw)) def sum( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: return type(self)(self.raw.sum(axis=axis, keepdims=keepdims)) def prod( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: return type(self)(self.raw.prod(axis=axis, keepdims=keepdims)) def mean( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: if self.raw.dtype not in [np.float16, np.float32, np.float64]: raise ValueError( f"Can only calculate the mean of floating types. Got {self.raw.dtype} instead." 
) return type(self)(self.raw.mean(axis=axis, keepdims=keepdims)) def min( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: return type(self)(self.raw.min(axis=axis, keepdims=keepdims)) def max( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: return type(self)(self.raw.max(axis=axis, keepdims=keepdims)) def minimum(self: TensorType, other: TensorOrScalar) -> TensorType: return type(self)(np.minimum(self.raw, unwrap1(other))) def maximum(self: TensorType, other: TensorOrScalar) -> TensorType: return type(self)(np.maximum(self.raw, unwrap1(other))) def argmin(self: TensorType, axis: Optional[int] = None) -> TensorType: return type(self)(self.raw.argmin(axis=axis)) def argmax(self: TensorType, axis: Optional[int] = None) -> TensorType: return type(self)(self.raw.argmax(axis=axis)) def argsort(self: TensorType, axis: int = -1) -> TensorType: return type(self)(self.raw.argsort(axis=axis)) def sort(self: TensorType, axis: int = -1) -> TensorType: return type(self)(self.raw.sort(axis=axis)) def topk( self: TensorType, k: int, sorted: bool = True ) -> Tuple[TensorType, TensorType]: # argpartition not yet implemented # wrapping indexing not supported in take() n = self.raw.shape[-1] idx = np.take(np.argsort(self.raw), np.arange(n - k, n), axis=-1) val = np.take_along_axis(self.raw, idx, axis=-1) if sorted: perm = np.flip(np.argsort(val, axis=-1), axis=-1) idx = np.take_along_axis(idx, perm, axis=-1) val = np.take_along_axis(self.raw, idx, axis=-1) return type(self)(val), type(self)(idx) def uniform( self: TensorType, shape: ShapeOrScalar, low: float = 0.0, high: float = 1.0 ) -> TensorType: if not isinstance(shape, Iterable): shape = (shape,) subkey = self._get_subkey() return type(self)(jax.random.uniform(subkey, shape, minval=low, maxval=high)) def normal( self: TensorType, shape: ShapeOrScalar, mean: float = 0.0, stddev: float = 1.0 ) -> TensorType: if not isinstance(shape, Iterable): shape = (shape,) subkey = self._get_subkey() return type(self)(jax.random.normal(subkey, shape) * stddev + mean) def ones(self: TensorType, shape: ShapeOrScalar) -> TensorType: return type(self)(np.ones(shape, dtype=self.raw.dtype)) def zeros(self: TensorType, shape: ShapeOrScalar) -> TensorType: return type(self)(np.zeros(shape, dtype=self.raw.dtype)) def ones_like(self: TensorType) -> TensorType: return type(self)(np.ones_like(self.raw)) def zeros_like(self: TensorType) -> TensorType: return type(self)(np.zeros_like(self.raw)) def full_like(self: TensorType, fill_value: float) -> TensorType: return type(self)(np.full_like(self.raw, fill_value)) def onehot_like( self: TensorType, indices: TensorType, *, value: float = 1 ) -> TensorType: if self.ndim != 2: raise ValueError("onehot_like only supported for 2D tensors") if indices.ndim != 1: raise ValueError("onehot_like requires 1D indices") if len(indices) != len(self): raise ValueError("length of indices must match length of tensor") x = np.arange(self.raw.shape[1]).reshape(1, -1) indices = indices.raw.reshape(-1, 1) return type(self)((x == indices) * value) def from_numpy(self: TensorType, a: Any) -> TensorType: return type(self)(np.asarray(a)) def _concatenate( self: TensorType, tensors: Iterable[TensorType], axis: int = 0 ) -> TensorType: # concatenates only "tensors", but not "self" tensors_ = unwrap_(*tensors) return type(self)(np.concatenate(tensors_, axis=axis)) def _stack( self: TensorType, tensors: Iterable[TensorType], axis: int = 0 ) -> TensorType: # stacks only 
"tensors", but not "self" tensors_ = unwrap_(*tensors) return type(self)(np.stack(tensors_, axis=axis)) def transpose(self: TensorType, axes: Optional[Axes] = None) -> TensorType: if axes is None: axes = tuple(range(self.ndim - 1, -1, -1)) return type(self)(np.transpose(self.raw, axes=axes)) def all( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: assert_bool(self) return type(self)(self.raw.all(axis=axis, keepdims=keepdims)) def any( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: assert_bool(self) return type(self)(self.raw.any(axis=axis, keepdims=keepdims)) def logical_and(self: TensorType, other: TensorOrScalar) -> TensorType: assert_bool(self) assert_bool(other) return type(self)(np.logical_and(self.raw, unwrap1(other))) def logical_or(self: TensorType, other: TensorOrScalar) -> TensorType: assert_bool(self) assert_bool(other) return type(self)(np.logical_or(self.raw, unwrap1(other))) def logical_not(self: TensorType) -> TensorType: assert_bool(self) return type(self)(np.logical_not(self.raw)) def exp(self: TensorType) -> TensorType: return type(self)(np.exp(self.raw)) def log(self: TensorType) -> TensorType: return type(self)(np.log(self.raw)) def log2(self: TensorType) -> TensorType: return type(self)(np.log2(self.raw)) def log10(self: TensorType) -> TensorType: return type(self)(np.log10(self.raw)) def log1p(self: TensorType) -> TensorType: return type(self)(np.log1p(self.raw)) def tile(self: TensorType, multiples: Axes) -> TensorType: multiples = unwrap1(multiples) if len(multiples) != self.ndim: raise ValueError("multiples requires one entry for each dimension") return type(self)(np.tile(self.raw, multiples)) def softmax(self: TensorType, axis: int = -1) -> TensorType: return type(self)(jax.nn.softmax(self.raw, axis=axis)) def log_softmax(self: TensorType, axis: int = -1) -> TensorType: return type(self)(jax.nn.log_softmax(self.raw, axis=axis)) def squeeze(self: TensorType, axis: Optional[AxisAxes] = None) -> TensorType: if axis is not None: # workaround for https://github.com/google/jax/issues/2284 axis = (axis,) if isinstance(axis, int) else axis shape = self.shape if any(shape[i] != 1 for i in axis): raise ValueError( "cannot select an axis to squeeze out which has size not equal to one" ) return type(self)(self.raw.squeeze(axis=axis)) def expand_dims(self: TensorType, axis: int) -> TensorType: return type(self)(np.expand_dims(self.raw, axis=axis)) def full(self: TensorType, shape: ShapeOrScalar, value: float) -> TensorType: if not isinstance(shape, Iterable): shape = (shape,) return type(self)(np.full(shape, value, dtype=self.raw.dtype)) def index_update( self: TensorType, indices: Any, values: TensorOrScalar ) -> TensorType: indices, values = unwrap_(indices, values) if isinstance(indices, tuple): indices = unwrap_(*indices) return type(self)(jax.ops.index_update(self.raw, indices, values)) def arange( self: TensorType, start: int, stop: Optional[int] = None, step: Optional[int] = None, ) -> TensorType: return type(self)(np.arange(start, stop, step)) def cumsum(self: TensorType, axis: Optional[int] = None) -> TensorType: return type(self)(self.raw.cumsum(axis=axis)) def flip(self: TensorType, axis: Optional[AxisAxes] = None) -> TensorType: return type(self)(np.flip(self.raw, axis=axis)) def meshgrid( self: TensorType, *tensors: TensorType, indexing: str = "xy" ) -> Tuple[TensorType, ...]: tensors = unwrap_(*tensors) outputs = np.meshgrid(self.raw, *tensors, indexing=indexing) return 
tuple(type(self)(out) for out in outputs) def pad( self: TensorType, paddings: Tuple[Tuple[int, int], ...], mode: str = "constant", value: float = 0, ) -> TensorType: if len(paddings) != self.ndim: raise ValueError("pad requires a tuple for each dimension") for p in paddings: if len(p) != 2: raise ValueError("pad requires a tuple for each dimension") if not (mode == "constant" or mode == "reflect"): raise ValueError("pad requires mode 'constant' or 'reflect'") if mode == "reflect": # PyTorch's pad has limited support for 'reflect' padding if self.ndim != 3 and self.ndim != 4: raise NotImplementedError # pragma: no cover k = self.ndim - 2 if paddings[:k] != ((0, 0),) * k: raise NotImplementedError # pragma: no cover if mode == "constant": return type(self)( np.pad(self.raw, paddings, mode=mode, constant_values=value) ) else: return type(self)(np.pad(self.raw, paddings, mode=mode)) def isnan(self: TensorType) -> TensorType: return type(self)(np.isnan(self.raw)) def isinf(self: TensorType) -> TensorType: return type(self)(np.isinf(self.raw)) def crossentropy(self: TensorType, labels: TensorType) -> TensorType: if self.ndim != 2: raise ValueError("crossentropy only supported for 2D logits tensors") if self.shape[:1] != labels.shape: raise ValueError("labels must be 1D and must match the length of logits") # for numerical reasons we subtract the max logit # (mathematically it doesn't matter!) # otherwise exp(logits) might become too large or too small logits = self.raw logits = logits - logits.max(axis=1, keepdims=True) e = np.exp(logits) s = np.sum(e, axis=1) ces = np.log(s) - np.take_along_axis( logits, labels.raw[:, np.newaxis], axis=1 ).squeeze(axis=1) return type(self)(ces) def slogdet(self: TensorType) -> Tuple[TensorType, TensorType]: sign, logabsdet = np.linalg.slogdet(self.raw) return type(self)(sign), type(self)(logabsdet) @overload def _value_and_grad_fn( self: TensorType, f: Callable[..., TensorType] ) -> Callable[..., Tuple[TensorType, TensorType]]: ... @overload # noqa: F811 (waiting for pyflakes > 2.1.1) def _value_and_grad_fn( self: TensorType, f: Callable[..., TensorType], has_aux: Literal[False] ) -> Callable[..., Tuple[TensorType, TensorType]]: ... @overload # noqa: F811 (waiting for pyflakes > 2.1.1) def _value_and_grad_fn( self: TensorType, f: Callable[..., Tuple[TensorType, Any]], has_aux: Literal[True], ) -> Callable[..., Tuple[TensorType, Any, TensorType]]: ... 
def _value_and_grad_fn( # noqa: F811 (waiting for pyflakes > 2.1.1) self: TensorType, f: Callable, has_aux: bool = False ) -> Callable[..., Tuple]: # f takes and returns JAXTensor instances # jax.value_and_grad accepts functions that take JAXTensor instances # because we registered JAXTensor as JAX type, but it still requires # the output to be a scalar (that is not not wrapped as a JAXTensor) # f_jax is like f but unwraps loss if has_aux: def f_jax(*args: Any, **kwargs: Any) -> Tuple[Any, Any]: loss, aux = f(*args, **kwargs) return loss.raw, aux else: def f_jax(*args: Any, **kwargs: Any) -> Any: # type: ignore loss = f(*args, **kwargs) return loss.raw value_and_grad_jax = jax.value_and_grad(f_jax, has_aux=has_aux) # value_and_grad is like value_and_grad_jax but wraps loss if has_aux: def value_and_grad( x: JAXTensor, *args: Any, **kwargs: Any ) -> Tuple[JAXTensor, Any, JAXTensor]: assert isinstance(x, JAXTensor) (loss, aux), grad = value_and_grad_jax(x, *args, **kwargs) assert grad.shape == x.shape return JAXTensor(loss), aux, grad else: def value_and_grad( # type: ignore x: JAXTensor, *args: Any, **kwargs: Any ) -> Tuple[JAXTensor, JAXTensor]: assert isinstance(x, JAXTensor) loss, grad = value_and_grad_jax(x, *args, **kwargs) assert grad.shape == x.shape return JAXTensor(loss), grad return value_and_grad def sign(self: TensorType) -> TensorType: return type(self)(np.sign(self.raw)) def sqrt(self: TensorType) -> TensorType: return type(self)(np.sqrt(self.raw)) def tanh(self: TensorType) -> TensorType: return type(self)(np.tanh(self.raw)) def float32(self: TensorType) -> TensorType: return self.astype(np.float32) def float64(self: TensorType) -> TensorType: return self.astype(np.float32) def where(self: TensorType, x: TensorOrScalar, y: TensorOrScalar) -> TensorType: x, y = unwrap_(x, y) return type(self)(np.where(self.raw, x, y)) def __lt__(self: TensorType, other: TensorOrScalar) -> TensorType: return type(self)(self.raw.__lt__(unwrap1(other))) def __le__(self: TensorType, other: TensorOrScalar) -> TensorType: return type(self)(self.raw.__le__(unwrap1(other))) def __eq__(self: TensorType, other: TensorOrScalar) -> TensorType: # type: ignore return type(self)(self.raw.__eq__(unwrap1(other))) def __ne__(self: TensorType, other: TensorOrScalar) -> TensorType: # type: ignore return type(self)(self.raw.__ne__(unwrap1(other))) def __gt__(self: TensorType, other: TensorOrScalar) -> TensorType: return type(self)(self.raw.__gt__(unwrap1(other))) def __ge__(self: TensorType, other: TensorOrScalar) -> TensorType: return type(self)(self.raw.__ge__(unwrap1(other))) def __getitem__(self: TensorType, index: Any) -> TensorType: if isinstance(index, tuple): index = tuple(getitem_preprocess(x) for x in index) else: index = getitem_preprocess(index) return type(self)(self.raw[index]) def take_along_axis(self: TensorType, index: TensorType, axis: int) -> TensorType: if axis % self.ndim != self.ndim - 1: raise NotImplementedError( "take_along_axis is currently only supported for the last axis" ) return type(self)(np.take_along_axis(self.raw, index.raw, axis=axis)) def bool(self: TensorType) -> TensorType: return self.astype(np.bool_) eagerpy-0.30.0/eagerpy/tensor/numpy.py000066400000000000000000000400361410374365400177300ustar00rootroot00000000000000from typing import ( Tuple, cast, Union, Any, Iterable, Optional, overload, Callable, TYPE_CHECKING, ) from typing_extensions import Literal import numpy as np from ..types import Axes, AxisAxes, Shape, ShapeOrScalar from .tensor import TensorType from .tensor 
import Tensor from .tensor import TensorOrScalar from .base import BaseTensor from .base import unwrap_ from .base import unwrap1 if TYPE_CHECKING: from .extensions import NormsMethods # noqa: F401 def assert_bool(x: Any) -> None: if not isinstance(x, Tensor): return if x.dtype != np.dtype("bool"): raise ValueError(f"requires dtype bool, got {x.dtype}, consider t.bool().all()") class NumPyTensor(BaseTensor): __slots__ = () # more specific types for the extensions norms: "NormsMethods[NumPyTensor]" def __init__(self, raw: "np.ndarray"): # type: ignore super().__init__(raw) @property def raw(self) -> "np.ndarray": # type: ignore return super().raw def numpy(self: TensorType) -> Any: a = self.raw.view() if a.flags.writeable: # without the check, we would attempt to set it on array # scalars, and that would fail a.flags.writeable = False return a def item(self) -> Union[int, float, bool]: return self.raw.item() # type: ignore @property def shape(self: TensorType) -> Shape: return cast(Tuple, self.raw.shape) def reshape(self: TensorType, shape: Union[Shape, int]) -> TensorType: if isinstance(shape, int): shape = (shape,) return type(self)(self.raw.reshape(shape)) def astype(self: TensorType, dtype: Any) -> TensorType: return type(self)(self.raw.astype(dtype)) def clip(self: TensorType, min_: float, max_: float) -> TensorType: return type(self)(np.clip(self.raw, min_, max_)) def square(self: TensorType) -> TensorType: return type(self)(np.square(self.raw)) def arctanh(self: TensorType) -> TensorType: return type(self)(np.arctanh(self.raw)) def sum( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: return type(self)(self.raw.sum(axis=axis, keepdims=keepdims)) def prod( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: return type(self)(self.raw.prod(axis=axis, keepdims=keepdims)) def mean( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: if self.raw.dtype not in [np.float16, np.float32, np.float64]: raise ValueError( f"Can only calculate the mean of floating types. Got {self.raw.dtype} instead." 
) return type(self)(self.raw.mean(axis=axis, keepdims=keepdims)) def min( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: return type(self)(self.raw.min(axis=axis, keepdims=keepdims)) def max( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: return type(self)(self.raw.max(axis=axis, keepdims=keepdims)) def minimum(self: TensorType, other: TensorOrScalar) -> TensorType: return type(self)(np.minimum(self.raw, unwrap1(other))) def maximum(self: TensorType, other: TensorOrScalar) -> TensorType: return type(self)(np.maximum(self.raw, unwrap1(other))) def argmin(self: TensorType, axis: Optional[int] = None) -> TensorType: return type(self)(self.raw.argmin(axis=axis)) def argmax(self: TensorType, axis: Optional[int] = None) -> TensorType: return type(self)(self.raw.argmax(axis=axis)) def argsort(self: TensorType, axis: int = -1) -> TensorType: return type(self)(self.raw.argsort(axis=axis)) def sort(self: TensorType, axis: int = -1) -> TensorType: return type(self)(np.sort(self.raw, axis=axis)) def topk( self: TensorType, k: int, sorted: bool = True ) -> Tuple[TensorType, TensorType]: idx = np.take(np.argpartition(self.raw, k - 1), np.arange(-k, 0), axis=-1) val = np.take_along_axis(self.raw, idx, axis=-1) if sorted: perm = np.flip(np.argsort(val, axis=-1), axis=-1) idx = np.take_along_axis(idx, perm, axis=-1) val = np.take_along_axis(self.raw, idx, axis=-1) return type(self)(val), type(self)(idx) def uniform( self: TensorType, shape: ShapeOrScalar, low: float = 0.0, high: float = 1.0 ) -> TensorType: return type(self)(np.random.uniform(low, high, size=shape)) def normal( self: TensorType, shape: ShapeOrScalar, mean: float = 0.0, stddev: float = 1.0 ) -> TensorType: return type(self)(np.random.normal(mean, stddev, size=shape)) def ones(self: TensorType, shape: ShapeOrScalar) -> TensorType: return type(self)(np.ones(shape, dtype=self.raw.dtype)) def zeros(self: TensorType, shape: ShapeOrScalar) -> TensorType: return type(self)(np.zeros(shape, dtype=self.raw.dtype)) def ones_like(self: TensorType) -> TensorType: return type(self)(np.ones_like(self.raw)) def zeros_like(self: TensorType) -> TensorType: return type(self)(np.zeros_like(self.raw)) def full_like(self: TensorType, fill_value: float) -> TensorType: return type(self)(np.full_like(self.raw, fill_value)) def onehot_like( self: TensorType, indices: TensorType, *, value: float = 1 ) -> TensorType: if self.ndim != 2: raise ValueError("onehot_like only supported for 2D tensors") if indices.ndim != 1: raise ValueError("onehot_like requires 1D indices") if len(indices) != len(self): raise ValueError("length of indices must match length of tensor") x = np.zeros_like(self.raw) rows = np.arange(len(x)) x[rows, indices.raw] = value return type(self)(x) def from_numpy(self: TensorType, a: Any) -> TensorType: return type(self)(np.asarray(a)) def _concatenate( self: TensorType, tensors: Iterable[TensorType], axis: int = 0 ) -> TensorType: # concatenates only "tensors", but not "self" tensors_ = unwrap_(*tensors) return type(self)(np.concatenate(tensors_, axis=axis)) def _stack( self: TensorType, tensors: Iterable[TensorType], axis: int = 0 ) -> TensorType: # stacks only "tensors", but not "self" tensors_ = unwrap_(*tensors) return type(self)(np.stack(tensors_, axis=axis)) def transpose(self: TensorType, axes: Optional[Axes] = None) -> TensorType: if axes is None: axes = tuple(range(self.ndim - 1, -1, -1)) return type(self)(np.transpose(self.raw, axes=axes)) def all( self: 
TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: assert_bool(self) return type(self)(self.raw.all(axis=axis, keepdims=keepdims)) def any( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: assert_bool(self) return type(self)(self.raw.any(axis=axis, keepdims=keepdims)) def logical_and(self: TensorType, other: TensorOrScalar) -> TensorType: assert_bool(self) assert_bool(other) return type(self)(np.logical_and(self.raw, unwrap1(other))) def logical_or(self: TensorType, other: TensorOrScalar) -> TensorType: assert_bool(self) assert_bool(other) return type(self)(np.logical_or(self.raw, unwrap1(other))) def logical_not(self: TensorType) -> TensorType: assert_bool(self) return type(self)(np.logical_not(self.raw)) def exp(self: TensorType) -> TensorType: return type(self)(np.exp(self.raw)) def log(self: TensorType) -> TensorType: return type(self)(np.log(self.raw)) def log2(self: TensorType) -> TensorType: return type(self)(np.log2(self.raw)) def log10(self: TensorType) -> TensorType: return type(self)(np.log10(self.raw)) def log1p(self: TensorType) -> TensorType: return type(self)(np.log1p(self.raw)) def tile(self: TensorType, multiples: Axes) -> TensorType: multiples = unwrap1(multiples) if len(multiples) != self.ndim: raise ValueError("multiples requires one entry for each dimension") return type(self)(np.tile(self.raw, multiples)) def softmax(self: TensorType, axis: int = -1) -> TensorType: # for numerical reasons we subtract the max logit # (mathematically it doesn't matter!) # otherwise exp(logits) might become too large or too small logits = self.raw logits = logits - logits.max(axis=axis, keepdims=True) e = np.exp(logits) return type(self)(e / e.sum(axis=axis, keepdims=True)) def log_softmax(self: TensorType, axis: int = -1) -> TensorType: # for numerical reasons we subtract the max logit # (mathematically it doesn't matter!) 
# otherwise exp(logits) might become too large or too small logits = self.raw logits = logits - logits.max(axis=axis, keepdims=True) log_sum_exp = np.log(np.exp(logits).sum(axis=axis, keepdims=True)) return type(self)(logits - log_sum_exp) def squeeze(self: TensorType, axis: Optional[AxisAxes] = None) -> TensorType: return type(self)(self.raw.squeeze(axis=axis)) def expand_dims(self: TensorType, axis: int) -> TensorType: return type(self)(np.expand_dims(self.raw, axis=axis)) def full(self: TensorType, shape: ShapeOrScalar, value: float) -> TensorType: return type(self)(np.full(shape, value, dtype=self.raw.dtype)) def index_update( self: TensorType, indices: Any, values: TensorOrScalar ) -> TensorType: indices, values = unwrap_(indices, values) if isinstance(indices, tuple): indices = unwrap_(*indices) x = self.raw.copy() x[indices] = values return type(self)(x) def arange( self: TensorType, start: int, stop: Optional[int] = None, step: Optional[int] = None, ) -> TensorType: return type(self)(np.arange(start, stop, step)) def cumsum(self: TensorType, axis: Optional[int] = None) -> TensorType: return type(self)(self.raw.cumsum(axis=axis)) def flip(self: TensorType, axis: Optional[AxisAxes] = None) -> TensorType: return type(self)(np.flip(self.raw, axis=axis)) def meshgrid( self: TensorType, *tensors: TensorType, indexing: str = "xy" ) -> Tuple[TensorType, ...]: tensors = unwrap_(*tensors) outputs = np.meshgrid(self.raw, *tensors, indexing=indexing) return tuple(type(self)(out) for out in outputs) def pad( self: TensorType, paddings: Tuple[Tuple[int, int], ...], mode: str = "constant", value: float = 0, ) -> TensorType: if len(paddings) != self.ndim: raise ValueError("pad requires a tuple for each dimension") for p in paddings: if len(p) != 2: raise ValueError("pad requires a tuple for each dimension") if not (mode == "constant" or mode == "reflect"): raise ValueError("pad requires mode 'constant' or 'reflect'") if mode == "reflect": # PyTorch's pad has limited support for 'reflect' padding if self.ndim != 3 and self.ndim != 4: raise NotImplementedError # pragma: no cover k = self.ndim - 2 if paddings[:k] != ((0, 0),) * k: raise NotImplementedError # pragma: no cover if mode == "constant": return type(self)( np.pad(self.raw, paddings, mode=mode, constant_values=value) ) else: return type(self)(np.pad(self.raw, paddings, mode=mode)) def isnan(self: TensorType) -> TensorType: return type(self)(np.isnan(self.raw)) def isinf(self: TensorType) -> TensorType: return type(self)(np.isinf(self.raw)) def crossentropy(self: TensorType, labels: TensorType) -> TensorType: if self.ndim != 2: raise ValueError("crossentropy only supported for 2D logits tensors") if self.shape[:1] != labels.shape: raise ValueError("labels must be 1D and must match the length of logits") # for numerical reasons we subtract the max logit # (mathematically it doesn't matter!) # otherwise exp(logits) might become too large or too small logits = self.raw logits = logits - logits.max(axis=1, keepdims=True) e = np.exp(logits) s = np.sum(e, axis=1) ces = np.log(s) - np.take_along_axis( logits, labels.raw[:, np.newaxis], axis=1 ).squeeze(axis=1) return type(self)(ces) def slogdet(self: TensorType) -> Tuple[TensorType, TensorType]: sign, logabsdet = np.linalg.slogdet(self.raw) return type(self)(sign), type(self)(logabsdet) @overload def _value_and_grad_fn( self: TensorType, f: Callable[..., TensorType] ) -> Callable[..., Tuple[TensorType, TensorType]]: ... 
@overload # noqa: F811 (waiting for pyflakes > 2.1.1) def _value_and_grad_fn( self: TensorType, f: Callable[..., TensorType], has_aux: Literal[False] ) -> Callable[..., Tuple[TensorType, TensorType]]: ... @overload # noqa: F811 (waiting for pyflakes > 2.1.1) def _value_and_grad_fn( self: TensorType, f: Callable[..., Tuple[TensorType, Any]], has_aux: Literal[True], ) -> Callable[..., Tuple[TensorType, Any, TensorType]]: ... def _value_and_grad_fn( # noqa: F811 (waiting for pyflakes > 2.1.1) self: TensorType, f: Callable, has_aux: bool = False ) -> Callable[..., Tuple]: # TODO: maybe implement this using https://github.com/HIPS/autograd raise NotImplementedError # pragma: no cover def sign(self: TensorType) -> TensorType: return type(self)(np.sign(self.raw)) def sqrt(self: TensorType) -> TensorType: return type(self)(np.sqrt(self.raw)) def tanh(self: TensorType) -> TensorType: return type(self)(np.tanh(self.raw)) def float32(self: TensorType) -> TensorType: return self.astype(np.float32) def float64(self: TensorType) -> TensorType: return self.astype(np.float64) def where(self: TensorType, x: TensorOrScalar, y: TensorOrScalar) -> TensorType: x, y = unwrap_(x, y) return type(self)(np.where(self.raw, x, y)) def __lt__(self: TensorType, other: TensorOrScalar) -> TensorType: return type(self)(self.raw.__lt__(unwrap1(other))) def __le__(self: TensorType, other: TensorOrScalar) -> TensorType: return type(self)(self.raw.__le__(unwrap1(other))) def __eq__(self: TensorType, other: TensorOrScalar) -> TensorType: # type: ignore return type(self)(self.raw.__eq__(unwrap1(other))) def __ne__(self: TensorType, other: TensorOrScalar) -> TensorType: # type: ignore return type(self)(self.raw.__ne__(unwrap1(other))) def __gt__(self: TensorType, other: TensorOrScalar) -> TensorType: return type(self)(self.raw.__gt__(unwrap1(other))) def __ge__(self: TensorType, other: TensorOrScalar) -> TensorType: return type(self)(self.raw.__ge__(unwrap1(other))) def __getitem__(self: TensorType, index: Any) -> TensorType: if isinstance(index, tuple): index = tuple(x.raw if isinstance(x, Tensor) else x for x in index) elif isinstance(index, Tensor): index = index.raw return type(self)(self.raw[index]) def take_along_axis(self: TensorType, index: TensorType, axis: int) -> TensorType: if axis % self.ndim != self.ndim - 1: raise NotImplementedError( "take_along_axis is currently only supported for the last axis" ) return type(self)(np.take_along_axis(self.raw, index.raw, axis=axis)) def bool(self: TensorType) -> TensorType: return self.astype(np.dtype("bool")) eagerpy-0.30.0/eagerpy/tensor/pytorch.py000066400000000000000000000522151410374365400202520ustar00rootroot00000000000000from typing import ( Tuple, cast, Union, Any, TypeVar, TYPE_CHECKING, Iterable, Optional, overload, Callable, ) from typing_extensions import Literal import numpy as np from importlib import import_module from ..types import Axes, AxisAxes, Shape, ShapeOrScalar from .tensor import Tensor from .tensor import TensorOrScalar from .base import BaseTensor from .base import unwrap_ from .base import unwrap1 if TYPE_CHECKING: import torch # for static analyzers from .extensions import NormsMethods # noqa: F401 else: # lazy import in PyTorchTensor torch = None # stricter TensorType to get additional type information from the raw method TensorType = TypeVar("TensorType", bound="PyTorchTensor") def assert_bool(x: Any) -> None: if not isinstance(x, Tensor): return if x.dtype != torch.bool: raise ValueError(f"requires dtype bool, got {x.dtype}, consider 
t.bool().all()") class PyTorchTensor(BaseTensor): __slots__ = () # more specific types for the extensions norms: "NormsMethods[PyTorchTensor]" def __init__(self, raw: "torch.Tensor"): global torch if torch is None: torch = import_module("torch") # type: ignore super().__init__(raw) @property def raw(self) -> "torch.Tensor": return cast(torch.Tensor, super().raw) def tanh(self: TensorType) -> TensorType: return type(self)(torch.tanh(self.raw)) def numpy(self: TensorType) -> Any: a = self.raw.detach().cpu().numpy() if a.flags.writeable: # without the check, we would attempt to set it on array # scalars, and that would fail a.flags.writeable = False return a def item(self) -> Union[int, float, bool]: return self.raw.item() @property def shape(self) -> Shape: return self.raw.shape def reshape(self: TensorType, shape: Union[Shape, int]) -> TensorType: if isinstance(shape, int): shape = (shape,) return type(self)(self.raw.reshape(shape)) def astype(self: TensorType, dtype: Any) -> TensorType: return type(self)(self.raw.to(dtype)) def clip(self: TensorType, min_: float, max_: float) -> TensorType: return type(self)(self.raw.clamp(min_, max_)) def square(self: TensorType) -> TensorType: return type(self)(self.raw ** 2) def arctanh(self: TensorType) -> TensorType: """ improve once this issue has been fixed: https://github.com/pytorch/pytorch/issues/10324 """ return type(self)(0.5 * (torch.log1p(self.raw) - torch.log1p(-self.raw))) def sum( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: if axis is None and not keepdims: return type(self)(self.raw.sum()) if axis is None: axis = tuple(range(self.ndim)) return type(self)(self.raw.sum(dim=axis, keepdim=keepdims)) def prod( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: if axis is None and not keepdims: return type(self)(self.raw.prod()) if axis is None: axis = tuple(range(self.ndim)) elif not isinstance(axis, Iterable): axis = (axis,) x = self.raw for i in sorted(axis, reverse=True): x = x.prod(i, keepdim=keepdims) return type(self)(x) def mean( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: if self.raw.dtype not in [torch.float16, torch.float32, torch.float64]: raise ValueError( f"Can only calculate the mean of floating types. Got {self.raw.dtype} instead." 
) if axis is None and not keepdims: return type(self)(self.raw.mean()) if axis is None: axis = tuple(range(self.ndim)) return type(self)(self.raw.mean(dim=axis, keepdim=keepdims)) def min( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: """ simplify once this issue has been fixed: https://github.com/pytorch/pytorch/issues/28213 """ if axis is None and not keepdims: return type(self)(self.raw.min()) if axis is None: axis = tuple(range(self.ndim)) elif not isinstance(axis, Iterable): axis = (axis,) x = self.raw for i in sorted(axis, reverse=True): x, _ = x.min(i, keepdim=keepdims) return type(self)(x) def max( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: """ simplify once this issue has been fixed: https://github.com/pytorch/pytorch/issues/28213 """ if axis is None and not keepdims: return type(self)(self.raw.max()) if axis is None: axis = tuple(range(self.ndim)) elif not isinstance(axis, Iterable): axis = (axis,) x = self.raw for i in sorted(axis, reverse=True): x, _ = x.max(i, keepdim=keepdims) return type(self)(x) def minimum(self: TensorType, other: TensorOrScalar) -> TensorType: if isinstance(other, Tensor): other_ = other.raw elif isinstance(other, int) or isinstance(other, float): other_ = torch.full_like(self.raw, other) else: raise TypeError( "expected x to be a Tensor, int or float" ) # pragma: no cover return type(self)(torch.min(self.raw, other_)) def maximum(self: TensorType, other: TensorOrScalar) -> TensorType: if isinstance(other, Tensor): other_ = other.raw elif isinstance(other, int) or isinstance(other, float): other_ = torch.full_like(self.raw, other) else: raise TypeError( "expected x to be a Tensor, int or float" ) # pragma: no cover return type(self)(torch.max(self.raw, other_)) def argmin(self: TensorType, axis: Optional[int] = None) -> TensorType: return type(self)(self.raw.argmin(dim=axis)) def argmax(self: TensorType, axis: Optional[int] = None) -> TensorType: return type(self)(self.raw.argmax(dim=axis)) def argsort(self: TensorType, axis: int = -1) -> TensorType: return type(self)(self.raw.argsort(dim=axis)) def sort(self: TensorType, axis: int = -1) -> TensorType: return type(self)(self.raw.sort(dim=axis).values) # type: ignore def topk( self: TensorType, k: int, sorted: bool = True ) -> Tuple[TensorType, TensorType]: values, indices = self.raw.topk(k, sorted=sorted) return type(self)(values), type(self)(indices) def uniform( self: TensorType, shape: ShapeOrScalar, low: float = 0.0, high: float = 1.0 ) -> TensorType: return type(self)( torch.rand(shape, dtype=self.raw.dtype, device=self.raw.device) * (high - low) + low ) def normal( self: TensorType, shape: ShapeOrScalar, mean: float = 0.0, stddev: float = 1.0 ) -> TensorType: return type(self)( torch.randn(shape, dtype=self.raw.dtype, device=self.raw.device) * stddev + mean ) def ones(self: TensorType, shape: ShapeOrScalar) -> TensorType: return type(self)( torch.ones(shape, dtype=self.raw.dtype, device=self.raw.device) ) def zeros(self: TensorType, shape: ShapeOrScalar) -> TensorType: return type(self)( torch.zeros(shape, dtype=self.raw.dtype, device=self.raw.device) ) def ones_like(self: TensorType) -> TensorType: return type(self)(torch.ones_like(self.raw)) def zeros_like(self: TensorType) -> TensorType: return type(self)(torch.zeros_like(self.raw)) def full_like(self: TensorType, fill_value: float) -> TensorType: return type(self)(torch.full_like(self.raw, fill_value)) def onehot_like( self: TensorType, indices: 
TensorType, *, value: float = 1 ) -> TensorType: if self.ndim != 2: raise ValueError("onehot_like only supported for 2D tensors") if indices.ndim != 1: raise ValueError("onehot_like requires 1D indices") if len(indices) != len(self): raise ValueError("length of indices must match length of tensor") x = torch.zeros_like(self.raw) rows = np.arange(x.shape[0]) x[rows, indices.raw] = value return type(self)(x) def from_numpy(self: TensorType, a: Any) -> TensorType: return type(self)(torch.as_tensor(a, device=self.raw.device)) def _concatenate( self: TensorType, tensors: Iterable[TensorType], axis: int = 0 ) -> TensorType: # concatenates only "tensors", but not "self" tensors_ = unwrap_(*tensors) return type(self)(torch.cat(tensors_, dim=axis)) def _stack( self: TensorType, tensors: Iterable[TensorType], axis: int = 0 ) -> TensorType: # stacks only "tensors", but not "self" tensors_ = unwrap_(*tensors) return type(self)(torch.stack(tensors_, dim=axis)) def transpose(self: TensorType, axes: Optional[Axes] = None) -> TensorType: if axes is None: axes = tuple(range(self.ndim - 1, -1, -1)) return type(self)(self.raw.permute(*axes)) def all( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: assert_bool(self) if axis is None and not keepdims: return type(self)(self.raw.all()) if axis is None: axis = tuple(range(self.ndim)) elif not isinstance(axis, Iterable): axis = (axis,) x = self.raw for i in sorted(axis, reverse=True): x = x.all(i, keepdim=keepdims) return type(self)(x) def any( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: assert_bool(self) if axis is None and not keepdims: return type(self)(self.raw.any()) if axis is None: axis = tuple(range(self.ndim)) elif not isinstance(axis, Iterable): axis = (axis,) x = self.raw for i in sorted(axis, reverse=True): x = x.any(i, keepdim=keepdims) return type(self)(x) def logical_and(self: TensorType, other: TensorOrScalar) -> TensorType: assert_bool(self) assert_bool(other) return type(self)(self.raw & unwrap1(other)) def logical_or(self: TensorType, other: TensorOrScalar) -> TensorType: assert_bool(self) assert_bool(other) return type(self)(self.raw | unwrap1(other)) def logical_not(self: TensorType) -> TensorType: assert_bool(self) return type(self)(~self.raw) def exp(self: TensorType) -> TensorType: return type(self)(torch.exp(self.raw)) def log(self: TensorType) -> TensorType: return type(self)(torch.log(self.raw)) def log2(self: TensorType) -> TensorType: return type(self)(torch.log2(self.raw)) def log10(self: TensorType) -> TensorType: return type(self)(torch.log10(self.raw)) def log1p(self: TensorType) -> TensorType: return type(self)(torch.log1p(self.raw)) def tile(self: TensorType, multiples: Axes) -> TensorType: if len(multiples) != self.ndim: raise ValueError("multiples requires one entry for each dimension") return type(self)(self.raw.repeat(multiples)) def softmax(self: TensorType, axis: int = -1) -> TensorType: return type(self)(torch.nn.functional.softmax(self.raw, dim=axis)) def log_softmax(self: TensorType, axis: int = -1) -> TensorType: return type(self)(torch.nn.functional.log_softmax(self.raw, dim=axis)) def squeeze(self: TensorType, axis: Optional[AxisAxes] = None) -> TensorType: if axis is None: return type(self)(self.raw.squeeze()) if not isinstance(axis, Iterable): axis = (axis,) x = self.raw for i in sorted(axis, reverse=True): if x.shape[i] != 1: raise ValueError( "cannot select an axis to squeeze out which has size not equal to one" ) x = 
x.squeeze(dim=i) return type(self)(x) def expand_dims(self: TensorType, axis: int) -> TensorType: return type(self)(self.raw.unsqueeze(dim=axis)) def full(self: TensorType, shape: ShapeOrScalar, value: float) -> TensorType: if not isinstance(shape, Iterable): shape = (shape,) return type(self)( torch.full(shape, value, dtype=self.raw.dtype, device=self.raw.device) ) def index_update( self: TensorType, indices: Any, values: TensorOrScalar ) -> TensorType: indices, values_ = unwrap_(indices, values) if isinstance(indices, tuple): indices = unwrap_(*indices) x = self.raw.clone() x[indices] = values_ return type(self)(x) def arange( self: TensorType, start: int, stop: Optional[int] = None, step: Optional[int] = None, ) -> TensorType: if step is None: step = 1 if stop is None: stop = start start = 0 return type(self)( torch.arange(start=start, end=stop, step=step, device=self.raw.device) ) def cumsum(self: TensorType, axis: Optional[int] = None) -> TensorType: if axis is None: return type(self)(self.raw.reshape(-1).cumsum(dim=0)) return type(self)(self.raw.cumsum(dim=axis)) def flip(self: TensorType, axis: Optional[AxisAxes] = None) -> TensorType: if axis is None: axis = tuple(range(self.ndim)) if not isinstance(axis, Iterable): axis = (axis,) return type(self)(self.raw.flip(dims=axis)) def meshgrid( self: TensorType, *tensors: TensorType, indexing: str = "xy" ) -> Tuple[TensorType, ...]: tensors = unwrap_(*tensors) if indexing == "ij" or len(tensors) == 0: outputs = torch.meshgrid(self.raw, *tensors) # type: ignore elif indexing == "xy": outputs = torch.meshgrid(tensors[0], self.raw, *tensors[1:]) # type: ignore else: raise ValueError( # pragma: no cover f"Valid values for indexing are 'xy' and 'ij', got {indexing}" ) results = [type(self)(out) for out in outputs] if indexing == "xy" and len(results) >= 2: results[0], results[1] = results[1], results[0] return tuple(results) def pad( self: TensorType, paddings: Tuple[Tuple[int, int], ...], mode: str = "constant", value: float = 0, ) -> TensorType: if len(paddings) != self.ndim: raise ValueError("pad requires a tuple for each dimension") for p in paddings: if len(p) != 2: raise ValueError("pad requires a tuple for each dimension") if not (mode == "constant" or mode == "reflect"): raise ValueError("pad requires mode 'constant' or 'reflect'") if mode == "reflect": # PyTorch's pad has limited support for 'reflect' padding if self.ndim != 3 and self.ndim != 4: raise NotImplementedError # pragma: no cover k = self.ndim - 2 if paddings[:k] != ((0, 0),) * k: raise NotImplementedError # pragma: no cover paddings = paddings[k:] paddings_ = list(x for p in reversed(paddings) for x in p) return type(self)( torch.nn.functional.pad(self.raw, paddings_, mode=mode, value=value) ) def isnan(self: TensorType) -> TensorType: return type(self)(torch.isnan(self.raw)) def isinf(self: TensorType) -> TensorType: return type(self)(torch.isinf(self.raw)) # type: ignore def crossentropy(self: TensorType, labels: TensorType) -> TensorType: if self.ndim != 2: raise ValueError("crossentropy only supported for 2D logits tensors") if self.shape[:1] != labels.shape: raise ValueError("labels must be 1D and must match the length of logits") return type(self)( torch.nn.functional.cross_entropy(self.raw, labels.raw, reduction="none") ) def slogdet(self: TensorType) -> Tuple[TensorType, TensorType]: sign, logabsdet = torch.slogdet(self.raw) return type(self)(sign), type(self)(logabsdet) @overload def _value_and_grad_fn( self: TensorType, f: Callable[..., TensorType] ) -> 
Callable[..., Tuple[TensorType, TensorType]]: ... @overload # noqa: F811 (waiting for pyflakes > 2.1.1) def _value_and_grad_fn( self: TensorType, f: Callable[..., TensorType], has_aux: Literal[False] ) -> Callable[..., Tuple[TensorType, TensorType]]: ... @overload # noqa: F811 (waiting for pyflakes > 2.1.1) def _value_and_grad_fn( self: TensorType, f: Callable[..., Tuple[TensorType, Any]], has_aux: Literal[True], ) -> Callable[..., Tuple[TensorType, Any, TensorType]]: ... def _value_and_grad_fn( # noqa: F811 (waiting for pyflakes > 2.1.1) self: TensorType, f: Callable, has_aux: bool = False ) -> Callable[..., Tuple]: def value_and_grad(x: TensorType, *args: Any, **kwargs: Any) -> Tuple: x = type(self)(x.raw.clone().requires_grad_()) if has_aux: loss, aux = f(x, *args, **kwargs) else: loss = f(x, *args, **kwargs) loss = loss.raw loss.backward() assert x.raw.grad is not None grad = type(self)(x.raw.grad) assert grad.shape == x.shape loss = loss.detach() loss = type(self)(loss) if has_aux: if isinstance(aux, PyTorchTensor): aux = PyTorchTensor(aux.raw.detach()) elif isinstance(aux, tuple): aux = tuple( PyTorchTensor(t.raw.detach()) if isinstance(t, PyTorchTensor) else t for t in aux ) return loss, aux, grad else: return loss, grad return value_and_grad def sign(self: TensorType) -> TensorType: return type(self)(torch.sign(self.raw)) def sqrt(self: TensorType) -> TensorType: return type(self)(torch.sqrt(self.raw)) def float32(self: TensorType) -> TensorType: return self.astype(torch.float32) def float64(self: TensorType) -> TensorType: return self.astype(torch.float64) def where(self: TensorType, x: TensorOrScalar, y: TensorOrScalar) -> TensorType: if isinstance(x, Tensor): x_ = x.raw elif isinstance(x, int) or isinstance(x, float): if isinstance(y, Tensor): dtype = y.raw.dtype else: dtype = torch.float32 x_ = torch.full_like(self.raw, x, dtype=dtype) else: raise TypeError( "expected x to be a Tensor, int or float" ) # pragma: no cover if isinstance(y, Tensor): y_ = y.raw elif isinstance(y, int) or isinstance(y, float): if isinstance(x, Tensor): dtype = x.raw.dtype else: dtype = torch.float32 y_ = torch.full_like(self.raw, y, dtype=dtype) return type(self)(torch.where(self.raw, x_, y_)) def __lt__(self: TensorType, other: TensorOrScalar) -> TensorType: return type(self)(self.raw.__lt__(unwrap1(other))) def __le__(self: TensorType, other: TensorOrScalar) -> TensorType: return type(self)(self.raw.__le__(unwrap1(other))) def __eq__(self: TensorType, other: TensorOrScalar) -> TensorType: # type: ignore return type(self)(self.raw.__eq__(unwrap1(other))) def __ne__(self: TensorType, other: TensorOrScalar) -> TensorType: # type: ignore return type(self)(self.raw.__ne__(unwrap1(other))) def __gt__(self: TensorType, other: TensorOrScalar) -> TensorType: return type(self)(self.raw.__gt__(unwrap1(other))) def __ge__(self: TensorType, other: TensorOrScalar) -> TensorType: return type(self)(self.raw.__ge__(unwrap1(other))) def __getitem__(self: TensorType, index: Any) -> TensorType: if isinstance(index, tuple): index = tuple(x.raw if isinstance(x, Tensor) else x for x in index) elif isinstance(index, Tensor): index = index.raw return type(self)(self.raw[index]) def take_along_axis(self: TensorType, index: TensorType, axis: int) -> TensorType: if axis % self.ndim != self.ndim - 1: raise NotImplementedError( "take_along_axis is currently only supported for the last axis" ) return type(self)(torch.gather(self.raw, axis, index.raw)) def bool(self: TensorType) -> TensorType: return self.astype(torch.bool) 
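# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not a file from this archive): a minimal
# example of driving the PyTorchTensor defined above through the public
# EagerPy API. It assumes `eagerpy` and `torch` are installed, and that
# `astensor`, `value_and_grad`, and `Tensor` are re-exported at the package
# level as elsewhere in this release; all variable names are examples.
# ---------------------------------------------------------------------------

import eagerpy as ep
import torch

# wrap a native tensor; subsequent operations stay on the PyTorch backend
x = ep.astensor(torch.arange(6, dtype=torch.float32).reshape(2, 3))


def loss_fn(x: ep.Tensor) -> ep.Tensor:
    # square() and sum() dispatch to the PyTorchTensor methods defined above
    return (x - 1.0).square().sum()


# value_and_grad clones the input with requires_grad_() and runs backward()
loss, grad = ep.value_and_grad(loss_fn, x)
assert grad.shape == x.shape

row_norms = x.norms.l2(axis=-1)  # per-row L2 norm via the norms extension
raw = grad.raw  # unwrap back to the native torch.Tensor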
eagerpy-0.30.0/eagerpy/tensor/tensor.py000066400000000000000000000362451410374365400201010ustar00rootroot00000000000000from abc import ABCMeta, abstractmethod from typing import ( TypeVar, Callable, Tuple, Any, overload, Iterable, Iterator, Union, Optional, Type, TYPE_CHECKING, cast, ) from typing_extensions import Literal, final from ..types import Axes, AxisAxes, Shape, ShapeOrScalar if TYPE_CHECKING: from .extensions import NormsMethods # noqa: F401 TensorType = TypeVar("TensorType", bound="Tensor") # using Tensor instead of TensorType because of a MyPy bug # https://github.com/python/mypy/issues/3644 TensorOrScalar = Union["Tensor", int, float] class LazyCachedAccessor: # supports caching under a different name (because Tensor uses __slots__ # and thus we cannot override the LazyCachedAccessor class var intself) # supports lazy extension loading to break cyclic dependencies def __init__(self, cache_name: str, extension_name: str): self._cache_name = cache_name self._extension_name = extension_name @property def _extension(self) -> Any: # Type[object]: # only imported once needed to break cyclic dependencies from . import extensions return getattr(extensions, self._extension_name) def __get__( self, instance: Optional["Tensor"], owner: Optional[Type["Tensor"]] = None ) -> Any: if instance is None: # accessed as a class attribute return self._extension methods = getattr(instance, self._cache_name, None) if methods is not None: return methods # create the extension for this instance methods = self._extension(instance) # add it to the instance to avoid recreation instance.__setattr__(self._cache_name, methods) return methods class Tensor(metaclass=ABCMeta): """Base class defining the common interface of all EagerPy Tensors""" # each extension neeeds a slot to cache the instantiated extension __slots__ = ("_norms",) __array_ufunc__ = None # shorten the class name to eagerpy.Tensor (does not help with MyPy) __module__ = "eagerpy" @abstractmethod def __init__(self, raw: Any): ... @property @abstractmethod def raw(self) -> Any: ... @property @abstractmethod def dtype(self: TensorType) -> Any: ... @abstractmethod def __repr__(self: TensorType) -> str: ... @abstractmethod def __format__(self: TensorType, format_spec: str) -> str: ... @abstractmethod def __getitem__(self: TensorType, index: Any) -> TensorType: ... @abstractmethod def __bool__(self: TensorType) -> bool: ... @abstractmethod def __len__(self: TensorType) -> int: ... @abstractmethod def __abs__(self: TensorType) -> TensorType: ... @abstractmethod def __neg__(self: TensorType) -> TensorType: ... @abstractmethod def __add__(self: TensorType, other: TensorOrScalar) -> TensorType: ... @abstractmethod def __radd__(self: TensorType, other: TensorOrScalar) -> TensorType: ... @abstractmethod def __sub__(self: TensorType, other: TensorOrScalar) -> TensorType: ... @abstractmethod def __rsub__(self: TensorType, other: TensorOrScalar) -> TensorType: ... @abstractmethod def __mul__(self: TensorType, other: TensorOrScalar) -> TensorType: ... @abstractmethod def __rmul__(self: TensorType, other: TensorOrScalar) -> TensorType: ... @abstractmethod def __truediv__(self: TensorType, other: TensorOrScalar) -> TensorType: ... @abstractmethod def __rtruediv__(self: TensorType, other: TensorOrScalar) -> TensorType: ... @abstractmethod def __floordiv__(self: TensorType, other: TensorOrScalar) -> TensorType: ... @abstractmethod def __rfloordiv__(self: TensorType, other: TensorOrScalar) -> TensorType: ... 
@abstractmethod def __mod__(self: TensorType, other: TensorOrScalar) -> TensorType: ... @abstractmethod def __lt__(self: TensorType, other: TensorOrScalar) -> TensorType: ... @abstractmethod def __le__(self: TensorType, other: TensorOrScalar) -> TensorType: ... @abstractmethod def __eq__(self: TensorType, other: TensorOrScalar) -> TensorType: # type: ignore # we ignore the type errors caused by wrong type annotations for object # https://github.com/python/typeshed/issues/3685 ... @abstractmethod def __ne__(self: TensorType, other: TensorOrScalar) -> TensorType: # type: ignore # we ignore the type errors caused by wrong type annotations for object # https://github.com/python/typeshed/issues/3685 ... @abstractmethod def __gt__(self: TensorType, other: TensorOrScalar) -> TensorType: ... @abstractmethod def __ge__(self: TensorType, other: TensorOrScalar) -> TensorType: ... @abstractmethod def __pow__(self: TensorType, exponent: TensorOrScalar) -> TensorType: ... @abstractmethod def sign(self: TensorType) -> TensorType: ... @abstractmethod def sqrt(self: TensorType) -> TensorType: ... @abstractmethod def tanh(self: TensorType) -> TensorType: ... @abstractmethod def float32(self: TensorType) -> TensorType: ... @abstractmethod def float64(self: TensorType) -> TensorType: ... @abstractmethod def where(self: TensorType, x: TensorOrScalar, y: TensorOrScalar) -> TensorType: ... @property @abstractmethod def ndim(self: TensorType) -> int: ... @abstractmethod def numpy(self: TensorType) -> Any: ... @abstractmethod def item(self: TensorType) -> Union[int, float, bool]: ... @property @abstractmethod def shape(self: TensorType) -> Shape: ... @abstractmethod def reshape(self: TensorType, shape: Union[Shape, int]) -> TensorType: ... @abstractmethod def astype(self: TensorType, dtype: Any) -> TensorType: ... @abstractmethod def clip(self: TensorType, min_: float, max_: float) -> TensorType: ... @abstractmethod def square(self: TensorType) -> TensorType: ... @abstractmethod def arctanh(self: TensorType) -> TensorType: ... @abstractmethod def sum( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: ... @abstractmethod def prod( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: ... @abstractmethod def mean( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: ... @abstractmethod def min( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: ... @abstractmethod def max( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: ... @abstractmethod def minimum(self: TensorType, other: TensorOrScalar) -> TensorType: ... @abstractmethod def maximum(self: TensorType, other: TensorOrScalar) -> TensorType: ... @abstractmethod def argmin(self: TensorType, axis: Optional[int] = None) -> TensorType: ... @abstractmethod def argmax(self: TensorType, axis: Optional[int] = None) -> TensorType: ... @abstractmethod def argsort(self: TensorType, axis: int = -1) -> TensorType: ... @abstractmethod def sort(self: TensorType, axis: int = -1) -> TensorType: ... @abstractmethod def topk( self: TensorType, k: int, sorted: bool = True ) -> Tuple[TensorType, TensorType]: ... @abstractmethod def uniform( self: TensorType, shape: ShapeOrScalar, low: float = 0.0, high: float = 1.0 ) -> TensorType: ... @abstractmethod def normal( self: TensorType, shape: ShapeOrScalar, mean: float = 0.0, stddev: float = 1.0 ) -> TensorType: ... 
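    # Illustrative note (not part of the archived source): tensor-creation
    # methods such as zeros(), ones(), uniform(), and normal() are instance
    # methods rather than free functions, so the new tensor can pick up the
    # backend, dtype, and device of `self`, e.g. (names are examples only):
    #
    #     noise = x.normal(x.shape, stddev=0.1)   # same backend/dtype/device as x
    #     mask = x.uniform(x.shape) < 0.5         # boolean tensor on x's backend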
@abstractmethod def ones(self: TensorType, shape: ShapeOrScalar) -> TensorType: ... @abstractmethod def zeros(self: TensorType, shape: ShapeOrScalar) -> TensorType: ... @abstractmethod def ones_like(self: TensorType) -> TensorType: ... @abstractmethod def zeros_like(self: TensorType) -> TensorType: ... @abstractmethod def full_like(self: TensorType, fill_value: float) -> TensorType: ... @abstractmethod def onehot_like( self: TensorType, indices: TensorType, *, value: float = 1 ) -> TensorType: ... @abstractmethod def from_numpy(self: TensorType, a: Any) -> TensorType: ... @abstractmethod def _concatenate( self: TensorType, tensors: Iterable[TensorType], axis: int = 0 ) -> TensorType: ... @abstractmethod def _stack( self: TensorType, tensors: Iterable[TensorType], axis: int = 0 ) -> TensorType: ... @abstractmethod def transpose(self: TensorType, axes: Optional[Axes] = None) -> TensorType: ... @abstractmethod def take_along_axis(self: TensorType, index: TensorType, axis: int) -> TensorType: ... @abstractmethod def all( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: ... @abstractmethod def any( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: ... @abstractmethod def logical_and(self: TensorType, other: TensorOrScalar) -> TensorType: ... @abstractmethod def logical_or(self: TensorType, other: TensorOrScalar) -> TensorType: ... @abstractmethod def logical_not(self: TensorType) -> TensorType: ... @abstractmethod def exp(self: TensorType) -> TensorType: ... @abstractmethod def log(self: TensorType) -> TensorType: ... @abstractmethod def log2(self: TensorType) -> TensorType: ... @abstractmethod def log10(self: TensorType) -> TensorType: ... @abstractmethod def log1p(self: TensorType) -> TensorType: ... @abstractmethod def tile(self: TensorType, multiples: Axes) -> TensorType: ... @abstractmethod def softmax(self: TensorType, axis: int = -1) -> TensorType: ... @abstractmethod def log_softmax(self: TensorType, axis: int = -1) -> TensorType: ... @abstractmethod def squeeze(self: TensorType, axis: Optional[AxisAxes] = None) -> TensorType: ... @abstractmethod def expand_dims(self: TensorType, axis: int) -> TensorType: ... @abstractmethod def full(self: TensorType, shape: ShapeOrScalar, value: float) -> TensorType: ... @abstractmethod def index_update( self: TensorType, indices: Any, values: TensorOrScalar ) -> TensorType: ... @abstractmethod def arange( self: TensorType, start: int, stop: Optional[int] = None, step: Optional[int] = None, ) -> TensorType: ... @abstractmethod def cumsum(self: TensorType, axis: Optional[int] = None) -> TensorType: ... @abstractmethod def flip(self: TensorType, axis: Optional[AxisAxes] = None) -> TensorType: ... @abstractmethod def meshgrid( self: TensorType, *tensors: TensorType, indexing: str = "xy" ) -> Tuple[TensorType, ...]: ... @abstractmethod def pad( self: TensorType, paddings: Tuple[Tuple[int, int], ...], mode: str = "constant", value: float = 0, ) -> TensorType: ... @abstractmethod def isnan(self: TensorType) -> TensorType: ... @abstractmethod def isinf(self: TensorType) -> TensorType: ... @abstractmethod def crossentropy(self: TensorType, labels: TensorType) -> TensorType: ... @abstractmethod def slogdet(matrix: TensorType) -> Tuple[TensorType, TensorType]: ... @overload def _value_and_grad_fn( self: TensorType, f: Callable[..., TensorType] ) -> Callable[..., Tuple[TensorType, TensorType]]: ... 
@overload # noqa: F811 (waiting for pyflakes > 2.1.1) def _value_and_grad_fn( self: TensorType, f: Callable[..., TensorType], has_aux: Literal[False] ) -> Callable[..., Tuple[TensorType, TensorType]]: ... @overload # noqa: F811 (waiting for pyflakes > 2.1.1) def _value_and_grad_fn( self: TensorType, f: Callable[..., Tuple[TensorType, Any]], has_aux: Literal[True], ) -> Callable[..., Tuple[TensorType, Any, TensorType]]: ... @abstractmethod # noqa: F811 (waiting for pyflakes > 2.1.1) def _value_and_grad_fn( self: TensorType, f: Callable, has_aux: bool = False ) -> Callable[..., Tuple]: ... @abstractmethod def bool(self: TensorType) -> TensorType: ... # ######################################################################### # aliases and shared implementations # ######################################################################### @final @property def T(self: TensorType) -> TensorType: return self.transpose() @final def abs(self: TensorType) -> TensorType: return self.__abs__() @final def pow(self: TensorType, exponent: TensorOrScalar) -> TensorType: return self.__pow__(exponent) @final def value_and_grad( self: TensorType, f: Callable[..., TensorType], *args: Any, **kwargs: Any ) -> Tuple[TensorType, TensorType]: return self._value_and_grad_fn(f, has_aux=False)(self, *args, **kwargs) @final def value_aux_and_grad( self: TensorType, f: Callable[..., Tuple[TensorType, Any]], *args: Any, **kwargs: Any, ) -> Tuple[TensorType, Any, TensorType]: return self._value_and_grad_fn(f, has_aux=True)(self, *args, **kwargs) def __iter__(self: TensorType) -> Iterator[TensorType]: for i in range(len(self)): yield self[i] @final def flatten(self: TensorType, start: int = 0, end: int = -1) -> TensorType: start = start % self.ndim end = end % self.ndim shape = self.shape[:start] + (-1,) + self.shape[end + 1 :] return self.reshape(shape) def __matmul__(self: TensorType, other: TensorType) -> TensorType: if self.ndim != 2 or other.ndim != 2: raise ValueError( f"matmul requires both tensors to be 2D, got {self.ndim}D and {other.ndim}D" ) return type(self)(self.raw.__matmul__(other.raw)) def matmul(self: TensorType, other: TensorType) -> TensorType: return self.__matmul__(other) # ######################################################################### # extensions # ######################################################################### norms = cast("NormsMethods[Tensor]", LazyCachedAccessor("_norms", "NormsMethods")) def istensor(x: Any) -> bool: return isinstance(x, Tensor) eagerpy-0.30.0/eagerpy/tensor/tensorflow.py000066400000000000000000000527161410374365400207720ustar00rootroot00000000000000from typing import ( Tuple, cast, Union, Any, TypeVar, TYPE_CHECKING, Iterable, Optional, overload, Callable, ) from typing_extensions import Literal import numpy as np from importlib import import_module import functools from ..types import Axes, AxisAxes, Shape, ShapeOrScalar from .. 
import index from .tensor import Tensor from .tensor import TensorOrScalar from .tensor import TensorType from .base import BaseTensor from .base import unwrap_ from .base import unwrap1 if TYPE_CHECKING: import tensorflow as tf # for static analyzers from .extensions import NormsMethods # noqa: F401 else: # lazy import in TensorFlowTensor tf = None FuncType = Callable[..., Any] F = TypeVar("F", bound=FuncType) def samedevice(f: F) -> F: @functools.wraps(f) def wrapper(self: "TensorFlowTensor", *args: Any, **kwargs: Any) -> Any: with tf.device(self.raw.device): return f(self, *args, **kwargs) return cast(F, wrapper) def common_dtype(f: F) -> F: @functools.wraps(f) def wrapper(self: "TensorFlowTensor", *args: Any, **kwargs: Any) -> Any: dtypes = {self.dtype} | {arg.dtype for arg in args if isinstance(arg, Tensor)} if len(dtypes) == 1: # all dtypes are the same, nothing more to do return f(self, *args, **kwargs) numpy_dtypes = [np.dtype(dtype.name) for dtype in dtypes] common = np.find_common_type(numpy_dtypes, []) common = getattr(tf, common.name) if self.dtype != common: self = self.astype(common) args = tuple( arg.astype(common) if isinstance(arg, Tensor) and arg.dtype != common else arg for arg in args ) return f(self, *args, **kwargs) return cast(F, wrapper) def assert_bool(x: Any) -> None: if not isinstance(x, Tensor): return if x.dtype != tf.bool: raise ValueError(f"requires dtype bool, got {x.dtype}, consider t.bool().all()") class TensorFlowTensor(BaseTensor): __slots__ = () # more specific types for the extensions norms: "NormsMethods[TensorFlowTensor]" def __init__(self, raw: "tf.Tensor"): # type: ignore global tf if tf is None: tf = import_module("tensorflow") super().__init__(raw) @property def raw(self) -> "tf.Tensor": # type: ignore return super().raw def numpy(self: TensorType) -> Any: a = self.raw.numpy() if a.flags.writeable: # without the check, we would attempt to set it on array # scalars, and that would fail a.flags.writeable = False return a def item(self: TensorType) -> Union[int, float, bool]: return self.numpy().item() # type: ignore @property def shape(self: TensorType) -> Shape: return tuple(self.raw.shape.as_list()) def reshape(self: TensorType, shape: Union[Shape, int]) -> TensorType: if isinstance(shape, int): shape = (shape,) return type(self)(tf.reshape(self.raw, shape)) def astype(self: TensorType, dtype: Any) -> TensorType: return type(self)(tf.cast(self.raw, dtype)) def clip(self: TensorType, min_: float, max_: float) -> TensorType: return type(self)(tf.clip_by_value(self.raw, min_, max_)) def square(self: TensorType) -> TensorType: return type(self)(tf.square(self.raw)) def arctanh(self: TensorType) -> TensorType: return type(self)(tf.atanh(self.raw)) def sum( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: if self.raw.dtype == tf.bool: return self.astype(tf.int64).sum(axis=axis, keepdims=keepdims) return type(self)(tf.reduce_sum(self.raw, axis=axis, keepdims=keepdims)) def prod( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: if self.raw.dtype == tf.bool: return self.astype(tf.int64).prod(axis=axis, keepdims=keepdims) return type(self)(tf.reduce_prod(self.raw, axis=axis, keepdims=keepdims)) def mean( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: if self.raw.dtype not in [tf.float16, tf.float32, tf.float64]: raise ValueError( f"Can only calculate the mean of floating types. Got {self.raw.dtype} instead." 
) return type(self)(tf.reduce_mean(self.raw, axis=axis, keepdims=keepdims)) def min( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: return type(self)(tf.reduce_min(self.raw, axis=axis, keepdims=keepdims)) def max( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: return type(self)(tf.reduce_max(self.raw, axis=axis, keepdims=keepdims)) def minimum(self: TensorType, other: TensorOrScalar) -> TensorType: return type(self)(tf.minimum(self.raw, unwrap1(other))) def maximum(self: TensorType, other: TensorOrScalar) -> TensorType: return type(self)(tf.maximum(self.raw, unwrap1(other))) def argmin(self: TensorType, axis: Optional[int] = None) -> TensorType: return type(self)(tf.argmin(self.raw, axis=axis)) def argmax(self: TensorType, axis: Optional[int] = None) -> TensorType: return type(self)(tf.argmax(self.raw, axis=axis)) def argsort(self: TensorType, axis: Optional[int] = -1) -> TensorType: return type(self)(tf.argsort(self.raw, axis=axis)) def sort(self: TensorType, axis: Optional[int] = -1) -> TensorType: return type(self)(tf.sort(self.raw, axis=axis)) def topk( self: TensorType, k: int, sorted: bool = True ) -> Tuple[TensorType, TensorType]: values, indices = tf.math.top_k(self.raw, k, sorted=sorted) return type(self)(values), type(self)(indices) @samedevice def uniform( self: TensorType, shape: ShapeOrScalar, low: float = 0.0, high: float = 1.0 ) -> TensorType: if not isinstance(shape, Iterable): shape = (shape,) return type(self)( tf.random.uniform(shape, minval=low, maxval=high, dtype=self.raw.dtype) ) @samedevice def normal( self: TensorType, shape: ShapeOrScalar, mean: float = 0.0, stddev: float = 1.0 ) -> TensorType: if not isinstance(shape, Iterable): shape = (shape,) return type(self)( tf.random.normal(shape, mean=mean, stddev=stddev, dtype=self.raw.dtype) ) @samedevice def ones(self: TensorType, shape: ShapeOrScalar) -> TensorType: return type(self)(tf.ones(shape, dtype=self.raw.dtype)) @samedevice def zeros(self: TensorType, shape: ShapeOrScalar) -> TensorType: return type(self)(tf.zeros(shape, dtype=self.raw.dtype)) def ones_like(self: TensorType) -> TensorType: return type(self)(tf.ones_like(self.raw)) def zeros_like(self: TensorType) -> TensorType: return type(self)(tf.zeros_like(self.raw)) def full_like(self: TensorType, fill_value: float) -> TensorType: fill_value = tf.cast(fill_value, self.raw.dtype) return type(self)(tf.fill(self.raw.shape, fill_value)) @samedevice def onehot_like( self: TensorType, indices: TensorType, *, value: float = 1 ) -> TensorType: if self.ndim != 2: raise ValueError("onehot_like only supported for 2D tensors") if indices.ndim != 1: raise ValueError("onehot_like requires 1D indices") if len(indices) != len(self): raise ValueError("length of indices must match length of tensor") value = tf.cast(value, self.raw.dtype) return type(self)( tf.one_hot( indices.raw, depth=self.raw.shape[-1], on_value=value, dtype=self.raw.dtype, ) ) @samedevice def from_numpy(self: TensorType, a: Any) -> TensorType: return type(self)(tf.convert_to_tensor(a)) def _concatenate( self: TensorType, tensors: Iterable[TensorType], axis: int = 0 ) -> TensorType: # concatenates only "tensors", but not "self" tensors_ = unwrap_(*tensors) return type(self)(tf.concat(tensors_, axis=axis)) def _stack( self: TensorType, tensors: Iterable[TensorType], axis: int = 0 ) -> TensorType: # stacks only "tensors", but not "self" tensors_ = unwrap_(*tensors) return type(self)(tf.stack(tensors_, axis=axis)) def 
transpose(self: TensorType, axes: Optional[Axes] = None) -> TensorType: if axes is None: axes = tuple(range(self.ndim - 1, -1, -1)) return type(self)(tf.transpose(self.raw, perm=axes)) def all( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: assert_bool(self) return type(self)(tf.reduce_all(self.raw, axis=axis, keepdims=keepdims)) def any( self: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: assert_bool(self) return type(self)(tf.reduce_any(self.raw, axis=axis, keepdims=keepdims)) def logical_and(self: TensorType, other: TensorOrScalar) -> TensorType: assert_bool(self) assert_bool(other) return type(self)(tf.logical_and(self.raw, unwrap1(other))) def logical_or(self: TensorType, other: TensorOrScalar) -> TensorType: assert_bool(self) assert_bool(other) return type(self)(tf.logical_or(self.raw, unwrap1(other))) def logical_not(self: TensorType) -> TensorType: assert_bool(self) return type(self)(tf.logical_not(self.raw)) def exp(self: TensorType) -> TensorType: return type(self)(tf.exp(self.raw)) def log(self: TensorType) -> TensorType: return type(self)(tf.math.log(self.raw)) def log2(self: TensorType) -> TensorType: return type(self)(tf.math.log(self.raw) / tf.math.log(2.0)) def log10(self: TensorType) -> TensorType: return type(self)(tf.math.log(self.raw) / tf.math.log(10.0)) def log1p(self: TensorType) -> TensorType: return type(self)(tf.math.log1p(self.raw)) def tile(self: TensorType, multiples: Axes) -> TensorType: multiples = unwrap1(multiples) if len(multiples) != self.ndim: raise ValueError("multiples requires one entry for each dimension") return type(self)(tf.tile(self.raw, multiples)) def softmax(self: TensorType, axis: int = -1) -> TensorType: return type(self)(tf.nn.softmax(self.raw, axis=axis)) def log_softmax(self: TensorType, axis: int = -1) -> TensorType: return type(self)(tf.nn.log_softmax(self.raw, axis=axis)) def squeeze(self: TensorType, axis: Optional[AxisAxes] = None) -> TensorType: return type(self)(tf.squeeze(self.raw, axis=axis)) def expand_dims(self: TensorType, axis: int) -> TensorType: return type(self)(tf.expand_dims(self.raw, axis=axis)) @samedevice def full(self: TensorType, shape: ShapeOrScalar, value: float) -> TensorType: if not isinstance(shape, Iterable): shape = (shape,) return type(self)(tf.fill(shape, value)) def index_update( self: TensorType, indices: Any, values: TensorOrScalar ) -> TensorType: indices, values_ = unwrap_(indices, values) del values if isinstance(indices, tuple): indices = unwrap_(*indices) x = self.raw if isinstance(indices, int): if isinstance(values_, int) or isinstance(values_, float): values_ = tf.fill(x.shape[-1:], values_) return type(self)( tf.tensor_scatter_nd_update(x, [[indices]], values_[None]) ) elif isinstance(indices, tuple) and any( isinstance(idx, slice) for idx in indices ): if ( len(indices) == x.ndim == 2 and indices[0] == index[:] and not isinstance(indices[1], slice) ): x = tf.transpose(x) if isinstance(values_, int) or isinstance(values_, float): values_ = tf.fill(x.shape[-1:], values_) result = tf.tensor_scatter_nd_update(x, [[indices[-1]]], values_[None]) return type(self)(tf.transpose(result)) else: raise NotImplementedError # pragma: no cover elif isinstance(indices, tuple): if all(idx.dtype in [tf.int32, tf.int64] for idx in indices): indices = [ tf.cast(idx, tf.int64) if idx.dtype == tf.int32 else idx for idx in indices ] indices = tf.stack(indices, axis=-1) if isinstance(values_, int) or isinstance(values_, float): values_ 
= tf.fill((indices.shape[0],), values_) return type(self)(tf.tensor_scatter_nd_update(x, indices, values_)) else: raise ValueError # pragma: no cover @samedevice def arange( self: TensorType, start: int, stop: Optional[int] = None, step: Optional[int] = None, ) -> TensorType: if step is None: step = 1 if stop is None: stop = start start = 0 return type(self)(tf.range(start, stop, step)) def cumsum(self: TensorType, axis: Optional[int] = None) -> TensorType: if axis is None: x = tf.reshape(self.raw, (-1,)) return type(self)(tf.cumsum(x, axis=0)) return type(self)(tf.cumsum(self.raw, axis=axis)) def flip(self: TensorType, axis: Optional[AxisAxes] = None) -> TensorType: if axis is None: axis = tuple(range(self.ndim)) if not isinstance(axis, Iterable): axis = (axis,) return type(self)(tf.reverse(self.raw, axis=axis)) def meshgrid( self: TensorType, *tensors: TensorType, indexing: str = "xy" ) -> Tuple[TensorType, ...]: tensors = unwrap_(*tensors) outputs = tf.meshgrid(self.raw, *tensors, indexing=indexing) return tuple(type(self)(out) for out in outputs) def pad( self: TensorType, paddings: Tuple[Tuple[int, int], ...], mode: str = "constant", value: float = 0, ) -> TensorType: if len(paddings) != self.ndim: raise ValueError("pad requires a tuple for each dimension") for p in paddings: if len(p) != 2: raise ValueError("pad requires a tuple for each dimension") if not (mode == "constant" or mode == "reflect"): raise ValueError("pad requires mode 'constant' or 'reflect'") if mode == "reflect": # PyTorch's pad has limited support for 'reflect' padding if self.ndim != 3 and self.ndim != 4: raise NotImplementedError # pragma: no cover k = self.ndim - 2 if paddings[:k] != ((0, 0),) * k: raise NotImplementedError # pragma: no cover return type(self)(tf.pad(self.raw, paddings, mode=mode, constant_values=value)) def isnan(self: TensorType) -> TensorType: return type(self)(tf.math.is_nan(self.raw)) def isinf(self: TensorType) -> TensorType: return type(self)(tf.math.is_inf(self.raw)) def crossentropy(self: TensorType, labels: TensorType) -> TensorType: if self.ndim != 2: raise ValueError("crossentropy only supported for 2D logits tensors") if self.shape[:1] != labels.shape: raise ValueError("labels must be 1D and must match the length of logits") return type(self)( tf.nn.sparse_softmax_cross_entropy_with_logits(labels.raw, self.raw) ) def slogdet(self: TensorType) -> Tuple[TensorType, TensorType]: sign, logabsdet = tf.linalg.slogdet(self.raw) return type(self)(sign), type(self)(logabsdet) @overload def _value_and_grad_fn( self: TensorType, f: Callable[..., TensorType] ) -> Callable[..., Tuple[TensorType, TensorType]]: ... @overload # noqa: F811 (waiting for pyflakes > 2.1.1) def _value_and_grad_fn( self: TensorType, f: Callable[..., TensorType], has_aux: Literal[False] ) -> Callable[..., Tuple[TensorType, TensorType]]: ... @overload # noqa: F811 (waiting for pyflakes > 2.1.1) def _value_and_grad_fn( self: TensorType, f: Callable[..., Tuple[TensorType, Any]], has_aux: Literal[True], ) -> Callable[..., Tuple[TensorType, Any, TensorType]]: ... 
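    # Editor's note (illustrative sketch, not part of the original source): the @overload
    # stubs above only refine the static types; the implementation below records the
    # computation with tf.GradientTape, so that, e.g.,
    #
    #     loss, grad = x.value_and_grad(lambda t: (t ** 2).sum())
    #
    # returns the loss and the gradient d(loss)/dx as EagerPy tensors, with the gradient
    # having the same shape as `x`.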
def _value_and_grad_fn( # noqa: F811 (waiting for pyflakes > 2.1.1) self: TensorType, f: Callable, has_aux: bool = False ) -> Callable[..., Tuple]: def value_and_grad(x: TensorType, *args: Any, **kwargs: Any) -> Tuple: # using tf.identity to make x independent from possible other instances of x in args x_ = TensorFlowTensor(tf.identity(x.raw)) del x with tf.GradientTape() as tape: tape.watch(x_.raw) if has_aux: loss, aux = f(x_, *args, **kwargs) else: loss = f(x_, *args, **kwargs) grad = tape.gradient(loss.raw, x_.raw) grad = TensorFlowTensor(grad) assert grad.shape == x_.shape if has_aux: return loss, aux, grad else: return loss, grad return value_and_grad def sign(self: TensorType) -> TensorType: return type(self)(tf.sign(self.raw)) def sqrt(self: TensorType) -> TensorType: return type(self)(tf.sqrt(self.raw)) def tanh(self: TensorType) -> TensorType: return type(self)(tf.tanh(self.raw)) def float32(self: TensorType) -> TensorType: return self.astype(tf.float32) def float64(self: TensorType) -> TensorType: return self.astype(tf.float64) def where(self: TensorType, x: TensorOrScalar, y: TensorOrScalar) -> TensorType: x, y = unwrap_(x, y) return type(self)(tf.where(self.raw, x, y)) @common_dtype def __lt__(self: TensorType, other: TensorOrScalar) -> TensorType: return type(self)(self.raw.__lt__(unwrap1(other))) @common_dtype def __le__(self: TensorType, other: TensorOrScalar) -> TensorType: return type(self)(self.raw.__le__(unwrap1(other))) @common_dtype def __eq__(self: TensorType, other: TensorOrScalar) -> TensorType: # type: ignore return type(self)(self.raw.__eq__(unwrap1(other))) @common_dtype def __ne__(self: TensorType, other: TensorOrScalar) -> TensorType: # type: ignore return type(self)(self.raw.__ne__(unwrap1(other))) @common_dtype def __gt__(self: TensorType, other: TensorOrScalar) -> TensorType: return type(self)(self.raw.__gt__(unwrap1(other))) @common_dtype def __ge__(self: TensorType, other: TensorOrScalar) -> TensorType: return type(self)(self.raw.__ge__(unwrap1(other))) def __getitem__(self: TensorType, index: Any) -> TensorType: if isinstance(index, tuple): index = tuple(x.raw if isinstance(x, Tensor) else x for x in index) basic = all( x is None or x is Ellipsis or isinstance(x, int) or isinstance(x, slice) for x in index ) if not basic: # workaround for missing support for this in TensorFlow index = [tf.convert_to_tensor(x) for x in index] shapes = [tuple(x.shape) for x in index] shape = tuple(max(x) for x in zip(*shapes)) int64 = any(x.dtype == tf.int64 for x in index) for i in range(len(index)): t = index[i] if int64: t = tf.cast(t, tf.int64) # pragma: no cover assert t.ndim == len(shape) tiling = [] for b, k in zip(shape, t.shape): if k == 1: tiling.append(b) elif k == b: tiling.append(1) else: raise ValueError( # pragma: no cover f"{tuple(t.shape)} cannot be broadcasted to {shape}" ) index[i] = tf.tile(t, tiling) index = tf.stack(index, axis=-1) return type(self)(tf.gather_nd(self.raw, index)) elif ( isinstance(index, range) or isinstance(index, list) or isinstance(index, np.ndarray) ): return type(self)(tf.gather(self.raw, index)) elif isinstance(index, Tensor): if index.raw.dtype == tf.bool: return type(self)(self.raw.__getitem__(index.raw)) else: return type(self)(tf.gather(self.raw, index.raw)) return type(self)(self.raw.__getitem__(index)) def take_along_axis(self: TensorType, index: TensorType, axis: int) -> TensorType: axis = batch_dims = axis % self.ndim if axis != self.ndim - 1: raise NotImplementedError( "take_along_axis is currently only supported for 
the last axis" ) return type(self)( tf.gather(self.raw, index.raw, axis=axis, batch_dims=batch_dims) ) def bool(self: TensorType) -> TensorType: return self.astype(tf.bool) eagerpy-0.30.0/eagerpy/types.py000066400000000000000000000011211410374365400164020ustar00rootroot00000000000000from typing import Union, Tuple, TYPE_CHECKING if TYPE_CHECKING: # for static analyzers import torch # noqa: F401 import tensorflow # noqa: F401 import jax # noqa: F401 import numpy # noqa: F401 Axes = Tuple[int, ...] AxisAxes = Union[int, Axes] Shape = Tuple[int, ...] ShapeOrScalar = Union[Shape, int] # tensorflow.Tensor, jax.numpy.ndarray and numpy.ndarray currently evaluate to Any # we can therefore only provide additional type information for torch.Tensor NativeTensor = Union[ "torch.Tensor", "tensorflow.Tensor", "jax.numpy.ndarray", "numpy.ndarray" ] eagerpy-0.30.0/eagerpy/utils.py000066400000000000000000000026251410374365400164100ustar00rootroot00000000000000from typing import overload from typing_extensions import Literal from .tensor import Tensor from .tensor import PyTorchTensor from .tensor import TensorFlowTensor from .tensor import JAXTensor from .tensor import NumPyTensor from . import modules @overload def get_dummy(framework: Literal["pytorch"]) -> PyTorchTensor: ... @overload def get_dummy(framework: Literal["tensorflow"]) -> TensorFlowTensor: ... @overload def get_dummy(framework: Literal["jax"]) -> JAXTensor: ... @overload def get_dummy(framework: Literal["numpy"]) -> NumPyTensor: ... @overload def get_dummy(framework: str) -> Tensor: ... def get_dummy(framework: str) -> Tensor: x: Tensor if framework == "pytorch": x = modules.torch.zeros(0) assert isinstance(x, PyTorchTensor) elif framework == "pytorch-gpu": x = modules.torch.zeros(0, device="cuda:0") # pragma: no cover assert isinstance(x, PyTorchTensor) # pragma: no cover elif framework == "tensorflow": x = modules.tensorflow.zeros(0) assert isinstance(x, TensorFlowTensor) elif framework == "jax": x = modules.jax.numpy.zeros(0) assert isinstance(x, JAXTensor) elif framework == "numpy": x = modules.numpy.zeros(0) assert isinstance(x, NumPyTensor) else: raise ValueError(f"unknown framework: {framework}") # pragma: no cover return x.float32() eagerpy-0.30.0/pydocmd.yml000066400000000000000000000074171410374365400154300ustar00rootroot00000000000000site_name: "" generate: - norms.md: - eagerpy.norms: - eagerpy.norms.l0 - eagerpy.norms.l1 - eagerpy.norms.l2 - eagerpy.norms.linf - eagerpy.norms.lp - tensor.md: - eagerpy.PyTorchTensor - eagerpy.TensorFlowTensor - eagerpy.JAXTensor - eagerpy.NumPyTensor - eagerpy.Tensor: - eagerpy.Tensor.sign - eagerpy.Tensor.sqrt - eagerpy.Tensor.tanh - eagerpy.Tensor.float32 - eagerpy.Tensor.where - eagerpy.Tensor.matmul - eagerpy.Tensor.ndim - eagerpy.Tensor.numpy - eagerpy.Tensor.item - eagerpy.Tensor.shape - eagerpy.Tensor.reshape - eagerpy.Tensor.take_along_axis - eagerpy.Tensor.astype - eagerpy.Tensor.clip - eagerpy.Tensor.square - eagerpy.Tensor.arctanh - eagerpy.Tensor.sum - eagerpy.Tensor.prod - eagerpy.Tensor.mean - eagerpy.Tensor.min - eagerpy.Tensor.max - eagerpy.Tensor.minimum - eagerpy.Tensor.maximum - eagerpy.Tensor.argmin - eagerpy.Tensor.argmax - eagerpy.Tensor.argsort - eagerpy.Tensor.uniform - eagerpy.Tensor.normal - eagerpy.Tensor.ones - eagerpy.Tensor.zeros - eagerpy.Tensor.ones_like - eagerpy.Tensor.zeros_like - eagerpy.Tensor.full_like - eagerpy.Tensor.onehot_like - eagerpy.Tensor.from_numpy - eagerpy.Tensor.transpose - eagerpy.Tensor.bool - eagerpy.Tensor.all - eagerpy.Tensor.any - 
eagerpy.Tensor.logical_and - eagerpy.Tensor.logical_or - eagerpy.Tensor.logical_not - eagerpy.Tensor.exp - eagerpy.Tensor.log - eagerpy.Tensor.log2 - eagerpy.Tensor.log10 - eagerpy.Tensor.log1p - eagerpy.Tensor.tile - eagerpy.Tensor.softmax - eagerpy.Tensor.log_softmax - eagerpy.Tensor.squeeze - eagerpy.Tensor.expand_dims - eagerpy.Tensor.full - eagerpy.Tensor.index_update - eagerpy.Tensor.arange - eagerpy.Tensor.cumsum - eagerpy.Tensor.flip - eagerpy.Tensor.meshgrid - eagerpy.Tensor.pad - eagerpy.Tensor.isnan - eagerpy.Tensor.isinf - eagerpy.Tensor.crossentropy - eagerpy.Tensor.T - eagerpy.Tensor.abs - eagerpy.Tensor.pow - eagerpy.Tensor.value_and_grad - eagerpy.Tensor.value_aux_and_grad - eagerpy.Tensor.flatten - eagerpy.Tensor.norms.l0 - eagerpy.Tensor.norms.l1 - eagerpy.Tensor.norms.l2 - eagerpy.Tensor.norms.linf - eagerpy.Tensor.norms.lp - eagerpy.Tensor.raw - eagerpy.Tensor.dtype - eagerpy.Tensor.__init__ - eagerpy.Tensor.__repr__ - eagerpy.Tensor.__format__ - eagerpy.Tensor.__getitem__ - eagerpy.Tensor.__iter__ - eagerpy.Tensor.__bool__ - eagerpy.Tensor.__len__ - eagerpy.Tensor.__abs__ - eagerpy.Tensor.__neg__ - eagerpy.Tensor.__add__ - eagerpy.Tensor.__radd__ - eagerpy.Tensor.__sub__ - eagerpy.Tensor.__rsub__ - eagerpy.Tensor.__mul__ - eagerpy.Tensor.__rmul__ - eagerpy.Tensor.__truediv__ - eagerpy.Tensor.__rtruediv__ - eagerpy.Tensor.__floordiv__ - eagerpy.Tensor.__rfloordiv__ - eagerpy.Tensor.__mod__ - eagerpy.Tensor.__lt__ - eagerpy.Tensor.__le__ - eagerpy.Tensor.__eq__ - eagerpy.Tensor.__ne__ - eagerpy.Tensor.__gt__ - eagerpy.Tensor.__ge__ - eagerpy.Tensor.__pow__ # Required by Pydoc-Markdown, but irrelevant to us. pages: [] gens_dir: ./docs/api # Render headers are Markdown tags rather than HTML headers: markdown eagerpy-0.30.0/pyproject.toml000066400000000000000000000004161410374365400161520ustar00rootroot00000000000000[tool.black] line-length = 88 target-version = ['py36', 'py37', 'py38'] include = '\.pyi?$' exclude = ''' /( \.eggs | \.git | \.hg | \.mypy_cache | \.tox | \.venv | _build | buck-out | build | dist # specific to EagerPy | .pytest_cache )/ ''' eagerpy-0.30.0/requirements-dev.txt000066400000000000000000000002441410374365400172750ustar00rootroot00000000000000flake8>=3.7.9 black>=19.10b0 pytest>=5.3.2 pytest-cov>=2.8.1 coverage>=5.0.3 codecov>=2.0.15 coveralls>=1.10.0 mypy>=0.761 pre-commit>=1.21.0 pydoc-markdown==2.0.5 eagerpy-0.30.0/requirements.txt000066400000000000000000000001221410374365400165140ustar00rootroot00000000000000-r requirements-dev.txt torch==1.4.0 jaxlib==0.1.37 jax==0.1.57 tensorflow==2.5.0 eagerpy-0.30.0/setup.cfg000066400000000000000000000020641410374365400150600ustar00rootroot00000000000000[flake8] ignore = E203, E266, E501, W503 max-line-length = 80 max-complexity = 18 select = B,C,E,F,W,T4,B9 [mypy] python_version = 3.6 warn_unused_ignores = True warn_unused_configs = True warn_return_any = True warn_redundant_casts = True warn_unreachable = True ignore_missing_imports = False disallow_any_unimported = True disallow_untyped_calls = True no_implicit_optional = True disallow_untyped_defs = True [mypy-numpy.*] ignore_missing_imports = True [mypy-jax.*] ignore_missing_imports = True [mypy-tensorflow] ignore_missing_imports = True [mypy-pytest] ignore_missing_imports = True [mypy-setuptools] ignore_missing_imports = True [tool:pytest] filterwarnings = ignore::DeprecationWarning ignore::PendingDeprecationWarning # produced by TensorFlow: ignore:.*can't resolve package from __spec__ or __package__.*:ImportWarning [coverage:report] exclude_lines = # 
see: http://coverage.readthedocs.io/en/latest/config.html # Have to re-enable the standard pragma pragma: no cover @abstractmethod @overload TYPE_CHECKING eagerpy-0.30.0/setup.py000066400000000000000000000031621410374365400147510ustar00rootroot00000000000000from setuptools import setup from setuptools import find_packages from os.path import join, dirname with open(join(dirname(__file__), "eagerpy/VERSION")) as f: version = f.read().strip() try: # obtain long description from README readme_path = join(dirname(__file__), "README.rst") with open(readme_path, encoding="utf-8") as f: README = f.read() # remove raw html not supported by PyPI README = "\n".join(README.split("\n")[3:]) except IOError: README = "" install_requires = ["numpy", "typing-extensions>=3.7.4.1"] tests_require = ["pytest>=5.3.5", "pytest-cov>=2.8.1"] setup( name="eagerpy", version=version, description="EagerPy is a thin wrapper around PyTorch, TensorFlow Eager, JAX and NumPy that unifies their interface and thus allows writing code that works natively across all of them.", long_description=README, long_description_content_type="text/x-rst", classifiers=[ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", "Intended Audience :: Science/Research", "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.6", "Topic :: Scientific/Engineering :: Artificial Intelligence", ], keywords="", author="Jonas Rauber", author_email="jonas.rauber@bethgelab.org", url="https://github.com/jonasrauber/eagerpy", license="MIT License", packages=find_packages(), include_package_data=True, zip_safe=False, install_requires=install_requires, extras_require={"testing": tests_require}, ) eagerpy-0.30.0/tests/000077500000000000000000000000001410374365400143775ustar00rootroot00000000000000eagerpy-0.30.0/tests/conftest.py000066400000000000000000000017641410374365400166060ustar00rootroot00000000000000from typing import Any, Optional import pytest import eagerpy as ep def pytest_addoption(parser: Any) -> None: parser.addoption("--backend") @pytest.fixture(scope="session") def dummy(request: Any) -> ep.Tensor: backend: Optional[str] = request.config.option.backend if backend is None: pytest.skip() assert False return ep.utils.get_dummy(backend) @pytest.fixture(scope="session") def t1(dummy: ep.Tensor) -> ep.Tensor: return ep.arange(dummy, 5).float32() @pytest.fixture(scope="session") def t1int(dummy: ep.Tensor) -> ep.Tensor: return ep.arange(dummy, 5) @pytest.fixture(scope="session") def t2(dummy: ep.Tensor) -> ep.Tensor: return ep.arange(dummy, 7, 17, 2).float32() @pytest.fixture(scope="session") def t2int(dummy: ep.Tensor) -> ep.Tensor: return ep.arange(dummy, 7, 17, 2) @pytest.fixture(scope="session", params=["t1", "t2"]) def t(request: Any, t1: ep.Tensor, t2: ep.Tensor) -> ep.Tensor: return {"t1": t1, "t2": t2}[request.param] eagerpy-0.30.0/tests/test_lib.py000066400000000000000000000004521410374365400165570ustar00rootroot00000000000000import pytest import eagerpy as ep @pytest.mark.parametrize("axis", [0, 1, -1]) def test_kl_div_with_logits(dummy: ep.Tensor, axis: int) -> None: logits_p = logits_q = ep.arange(dummy, 12).float32().reshape((3, 4)) assert (ep.kl_div_with_logits(logits_p, logits_q, axis=axis) == 0).all() eagerpy-0.30.0/tests/test_main.py000066400000000000000000001101611410374365400167340ustar00rootroot00000000000000from typing import Callable, Dict, Any, Tuple, Union, Optional, cast import pytest import functools import itertools import numpy as np import 
eagerpy as ep from eagerpy import Tensor from eagerpy.types import Shape, AxisAxes # make sure there are no undecorated tests in the "special tests" section below # -> /\n\ndef test_ # make sure the undecorated tests in the "normal tests" section all contain # assertions and do not return something # -> /\n return ############################################################################### # normal tests # - no decorator # - assertions ############################################################################### def test_astensor_raw(t: Tensor) -> None: assert (ep.astensor(t.raw) == t).all() def test_astensor_tensor(t: Tensor) -> None: assert (ep.astensor(t) == t).all() def test_astensor_restore_raw(t: Tensor) -> None: r = t.raw y, restore_type = ep.astensor_(r) assert (y == t).all() assert type(restore_type(y)) == type(r) y = y + 1 assert type(restore_type(y)) == type(r) def test_astensor_restore_tensor(t: Tensor) -> None: r = t y, restore_type = ep.astensor_(r) assert (y == t).all() assert type(restore_type(y)) == type(r) y = y + 1 assert type(restore_type(y)) == type(r) def test_astensors_raw(t: Tensor) -> None: ts = (t, t + 1, t + 2) rs = tuple(t.raw for t in ts) ys = ep.astensors(*rs) assert isinstance(ys, tuple) assert len(ts) == len(ys) for ti, yi in zip(ts, ys): assert (ti == yi).all() def test_astensors_tensor(t: Tensor) -> None: ts = (t, t + 1, t + 2) ys = ep.astensors(*ts) assert isinstance(ys, tuple) assert len(ts) == len(ys) for ti, yi in zip(ts, ys): assert (ti == yi).all() def test_astensors_raw_restore(t: Tensor) -> None: ts = (t, t + 1, t + 2) rs = tuple(t.raw for t in ts) ys, restore_type = ep.astensors_(*rs) assert isinstance(ys, tuple) assert len(ts) == len(ys) for ti, yi in zip(ts, ys): assert (ti == yi).all() ys = tuple(y + 1 for y in ys) xs = restore_type(*ys) assert isinstance(xs, tuple) assert len(xs) == len(ys) for xi, ri in zip(xs, rs): assert type(xi) == type(ri) x0 = restore_type(ys[0]) assert not isinstance(x0, tuple) def test_astensors_tensors_restore(t: Tensor) -> None: ts = (t, t + 1, t + 2) rs = ts ys, restore_type = ep.astensors_(*rs) assert isinstance(ys, tuple) assert len(ts) == len(ys) for ti, yi in zip(ts, ys): assert (ti == yi).all() ys = tuple(y + 1 for y in ys) xs = restore_type(*ys) assert isinstance(xs, tuple) assert len(xs) == len(ys) for xi, ri in zip(xs, rs): assert type(xi) == type(ri) x0 = restore_type(ys[0]) assert not isinstance(x0, tuple) # type: ignore def test_module() -> None: assert ep.istensor(ep.numpy.tanh([3, 5])) assert not ep.istensor(ep.numpy.tanh(3)) def test_module_dir() -> None: assert "zeros" in dir(ep.numpy) def test_repr(t: Tensor) -> None: assert not repr(t).startswith("<") t = ep.zeros(t, (10, 10)) assert not repr(t).startswith("<") assert len(repr(t).split("\n")) > 1 def test_logical_or_manual(t: Tensor) -> None: assert (ep.logical_or(t < 3, ep.zeros_like(t).bool()) == (t < 3)).all() def test_logical_not_manual(t: Tensor) -> None: assert (ep.logical_not(t > 3) == (t <= 3)).all() def test_softmax_manual(t: Tensor) -> None: s = ep.softmax(t) assert (s >= 0).all() assert (s <= 1).all() np.testing.assert_allclose(s.sum().numpy(), 1.0, rtol=1e-6) def test_log_softmax_manual(t: Tensor) -> None: np.testing.assert_allclose( ep.log_softmax(t).exp().numpy(), ep.softmax(t).numpy(), rtol=1e-6 ) def test_value_and_grad_fn(dummy: Tensor) -> None: if isinstance(dummy, ep.NumPyTensor): pytest.skip() def f(x: ep.Tensor) -> ep.Tensor: return x.square().sum() vgf = ep.value_and_grad_fn(dummy, f) t = ep.arange(dummy, 
8).float32().reshape((2, 4)) v, g = vgf(t) assert v.item() == 140 assert (g == 2 * t).all() def test_value_and_grad_fn_with_aux(dummy: Tensor) -> None: if isinstance(dummy, ep.NumPyTensor): pytest.skip() def f(x: Tensor) -> Tuple[Tensor, Tensor]: x = x.square() return x.sum(), x vgf = ep.value_and_grad_fn(dummy, f, has_aux=True) t = ep.arange(dummy, 8).float32().reshape((2, 4)) v, aux, g = vgf(t) assert v.item() == 140 assert (aux == t.square()).all() assert (g == 2 * t).all() def test_value_and_grad(dummy: Tensor) -> None: if isinstance(dummy, ep.NumPyTensor): pytest.skip() def f(x: Tensor) -> Tensor: return x.square().sum() t = ep.arange(dummy, 8).float32().reshape((2, 4)) v, g = ep.value_and_grad(f, t) assert v.item() == 140 assert (g == 2 * t).all() def test_value_aux_and_grad(dummy: Tensor) -> None: if isinstance(dummy, ep.NumPyTensor): pytest.skip() def f(x: Tensor) -> Tuple[Tensor, Tensor]: x = x.square() return x.sum(), x t = ep.arange(dummy, 8).float32().reshape((2, 4)) v, aux, g = ep.value_aux_and_grad(f, t) assert v.item() == 140 assert (aux == t.square()).all() assert (g == 2 * t).all() def test_value_aux_and_grad_multiple_aux(dummy: Tensor) -> None: if isinstance(dummy, ep.NumPyTensor): pytest.skip() def f(x: Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: x = x.square() return x.sum(), (x, x + 1) t = ep.arange(dummy, 8).float32().reshape((2, 4)) v, (aux0, aux1), g = ep.value_aux_and_grad(f, t) assert v.item() == 140 assert (aux0 == t.square()).all() assert (aux1 == t.square() + 1).all() assert (g == 2 * t).all() def test_value_and_grad_multiple_args(dummy: Tensor) -> None: if isinstance(dummy, ep.NumPyTensor): pytest.skip() def f(x: Tensor, y: Tensor) -> Tensor: return (x * y).sum() t = ep.arange(dummy, 8).float32().reshape((2, 4)) v, g = ep.value_and_grad(f, t, t) assert v.item() == 140 assert (g == t).all() def test_logical_and_manual(t: Tensor) -> None: assert (ep.logical_and(t < 3, ep.ones_like(t).bool()) == (t < 3)).all() def test_transpose_1d(dummy: Tensor) -> None: t = ep.arange(dummy, 8).float32() assert (ep.transpose(t) == t).all() def test_onehot_like_raises(dummy: Tensor) -> None: t = ep.arange(dummy, 18).float32().reshape((6, 3)) indices = ep.arange(t, 6) // 2 ep.onehot_like(t, indices) t = ep.arange(dummy, 90).float32().reshape((6, 3, 5)) indices = ep.arange(t, 6) // 2 with pytest.raises(ValueError): ep.onehot_like(t, indices) t = ep.arange(dummy, 18).float32().reshape((6, 3)) indices = ep.arange(t, 6).reshape((6, 1)) // 2 with pytest.raises(ValueError): ep.onehot_like(t, indices) t = ep.arange(dummy, 18).float32().reshape((6, 3)) indices = ep.arange(t, 5) // 2 with pytest.raises(ValueError): ep.onehot_like(t, indices) def test_tile_raises(t: Tensor) -> None: ep.tile(t, (3,) * t.ndim) with pytest.raises(ValueError): ep.tile(t, (3,) * (t.ndim - 1)) def test_pad_raises(dummy: Tensor) -> None: t = ep.arange(dummy, 120).reshape((2, 3, 4, 5)).float32() ep.pad(t, ((0, 0), (0, 0), (2, 3), (1, 2)), mode="constant") with pytest.raises(ValueError): ep.pad(t, ((0, 0), (2, 3), (1, 2)), mode="constant") with pytest.raises(ValueError): ep.pad( t, ((0, 0), (0, 0, 1, 2), (2, 3), (1, 2)), mode="constant" # type: ignore ) with pytest.raises(ValueError): ep.pad(t, ((0, 0), (0, 0), (2, 3), (1, 2)), mode="foo") def test_mean_bool(t: Tensor) -> None: with pytest.raises(ValueError): ep.mean(t != 0) def test_mean_int(t: Tensor) -> None: with pytest.raises(ValueError): ep.mean(ep.arange(t, 5)) @pytest.mark.parametrize("f", [ep.logical_and, ep.logical_or]) def 
test_logical_and_nonboolean( t: Tensor, f: Callable[[Tensor, Tensor], Tensor] ) -> None: t = t.float32() f(t > 1, t > 1) with pytest.raises(ValueError): f(t, t > 1) with pytest.raises(ValueError): f(t > 1, t) with pytest.raises(ValueError): f(t, t) def test_crossentropy_raises(dummy: Tensor) -> None: t = ep.arange(dummy, 50).reshape((10, 5)).float32() t = t / t.max() ep.crossentropy(t, t.argmax(axis=-1)) t = ep.arange(dummy, 150).reshape((10, 5, 3)).float32() t = t / t.max() with pytest.raises(ValueError): ep.crossentropy(t, t.argmax(axis=-1)) t = ep.arange(dummy, 50).reshape((10, 5)).float32() t = t / t.max() with pytest.raises(ValueError): ep.crossentropy(t, t.argmax(axis=-1)[:8]) def test_matmul_raise(dummy: Tensor) -> None: t = ep.arange(dummy, 8).float32().reshape((2, 4)) ep.matmul(t, t.T) with pytest.raises(ValueError): ep.matmul(t, t[0]) with pytest.raises(ValueError): ep.matmul(t[0], t) with pytest.raises(ValueError): ep.matmul(t[0], t[0]) def test_take_along_axis_2d_first_raises(dummy: Tensor) -> None: t = ep.arange(dummy, 32).float32().reshape((8, 4)) indices = ep.arange(t, t.shape[-1]) % t.shape[0] with pytest.raises(NotImplementedError): ep.take_along_axis(t, indices[ep.newaxis], axis=0) def test_norms_class() -> None: assert ep.Tensor.norms is not None def test_numpy_readonly(t: Tensor) -> None: a = t.numpy() assert a.flags.writeable is False with pytest.raises(ValueError, match="read-only"): a[:] += 1 def test_numpy_inplace(t: Tensor) -> None: copy = t + 0 a = t.numpy().copy() a[:] += 1 assert (t == copy).all() def test_iter_list_stack(t: Tensor) -> None: t2 = ep.stack(list(iter(t))) assert t.shape == t2.shape assert (t == t2).all() def test_list_stack(t: Tensor) -> None: t2 = ep.stack(list(t)) assert t.shape == t2.shape assert (t == t2).all() def test_iter_next(t: Tensor) -> None: assert isinstance(next(iter(t)), Tensor) def test_flatten(dummy: Tensor) -> None: t = ep.ones(dummy, (16, 3, 32, 32)) assert ep.flatten(t).shape == (16 * 3 * 32 * 32,) assert ep.flatten(t, start=1).shape == (16, 3 * 32 * 32) assert ep.flatten(t, start=2).shape == (16, 3, 32 * 32) assert ep.flatten(t, start=3).shape == (16, 3, 32, 32) assert ep.flatten(t, end=-2).shape == (16 * 3 * 32, 32) assert ep.flatten(t, end=-3).shape == (16 * 3, 32, 32) assert ep.flatten(t, end=-4).shape == (16, 3, 32, 32) assert ep.flatten(t, start=1, end=-2).shape == (16, 3 * 32, 32) @pytest.mark.parametrize("axis", [None, 0, 1, (0, 1)]) def test_squeeze_not_one(dummy: Tensor, axis: Optional[AxisAxes]) -> None: t = ep.zeros(dummy, (3, 4, 5)) if axis is None: t.squeeze(axis=axis) else: with pytest.raises(Exception): # squeezing specifc axis should fail if they are not 1 t.squeeze(axis=axis) ############################################################################### # special tests # - decorated with compare_* # - return values ############################################################################### def get_numpy_kwargs(kwargs: Any) -> Dict: return { k: ep.astensor(t.numpy()) if ep.istensor(t) else t for k, t in kwargs.items() } def compare_all(f: Callable[..., Tensor]) -> Callable[..., None]: """A decorator to simplify writing test functions""" @functools.wraps(f) def test_fn(*args: Any, **kwargs: Any) -> None: assert len(args) == 0 nkwargs = get_numpy_kwargs(kwargs) t = f(*args, **kwargs) n = f(*args, **nkwargs) t = t.numpy() n = n.numpy() assert t.shape == n.shape assert (t == n).all() return test_fn def compare_allclose(*args: Any, rtol: float = 1e-07, atol: float = 0) -> Callable: """A decorator to simplify 
writing test functions""" def compare_allclose_inner(f: Callable[..., Tensor]) -> Callable[..., None]: @functools.wraps(f) def test_fn(*args: Any, **kwargs: Any) -> None: assert len(args) == 0 nkwargs = get_numpy_kwargs(kwargs) t = f(*args, **kwargs) n = f(*args, **nkwargs) t = t.numpy() n = n.numpy() assert t.shape == n.shape np.testing.assert_allclose(t, n, rtol=rtol, atol=atol) return test_fn if len(args) == 1 and callable(args[0]): # decorator applied without parenthesis return compare_allclose_inner(args[0]) return compare_allclose_inner def compare_equal( f: Callable[..., Union[Tensor, int, float, bool, Shape]] ) -> Callable[..., None]: """A decorator to simplify writing test functions""" @functools.wraps(f) def test_fn(*args: Any, **kwargs: Any) -> None: assert len(args) == 0 nkwargs = get_numpy_kwargs(kwargs) t = f(*args, **kwargs) n = f(*args, **nkwargs) assert isinstance(t, type(n)) assert t == n return test_fn @compare_equal def test_format(dummy: Tensor) -> bool: t = ep.arange(dummy, 5).sum() return f"{t:.1f}" == "10.0" @compare_equal def test_item(t: Tensor) -> float: t = t.sum() return t.item() @compare_equal def test_len(t: Tensor) -> int: return len(t) @compare_equal def test_scalar_bool(t: Tensor) -> bool: return bool(ep.sum(t) == 0) @compare_all def test_neg(t: Tensor) -> Tensor: return -t @compare_all def test_square(t: Tensor) -> Tensor: return ep.square(t) @compare_allclose def test_pow(t: Tensor) -> Tensor: return ep.pow(t, 3) @compare_allclose def test_pow_float(t: Tensor) -> Tensor: return ep.pow(t, 2.5) @compare_allclose def test_pow_op(t: Tensor) -> Tensor: return t ** 3 @compare_allclose def test_pow_tensor(t: Tensor) -> Tensor: return ep.pow(t, (t + 0.5)) @compare_allclose def test_pow_op_tensor(t: Tensor) -> Tensor: return t ** (t + 0.5) @compare_all def test_add(t1: Tensor, t2: Tensor) -> Tensor: return t1 + t2 @compare_all def test_add_scalar(t: Tensor) -> Tensor: return t + 3 @compare_all def test_radd_scalar(t: Tensor) -> Tensor: return 3 + t @compare_all def test_sub(t1: Tensor, t2: Tensor) -> Tensor: return t1 - t2 @compare_all def test_sub_scalar(t: Tensor) -> Tensor: return t - 3 @compare_all def test_rsub_scalar(t: Tensor) -> Tensor: return 3 - t @compare_all def test_mul(t1: Tensor, t2: Tensor) -> Tensor: return t1 * t2 @compare_all def test_mul_scalar(t: Tensor) -> Tensor: return t * 3 @compare_all def test_rmul_scalar(t: Tensor) -> Tensor: return 3 * t @compare_allclose def test_truediv(t1: Tensor, t2: Tensor) -> Tensor: return t1 / t2 @compare_allclose(rtol=1e-6) def test_truediv_scalar(t: Tensor) -> Tensor: return t / 3 @compare_allclose def test_rtruediv_scalar(t: Tensor) -> Tensor: return 3 / (abs(t) + 3e-8) @compare_allclose def test_floordiv(t1: Tensor, t2: Tensor) -> Tensor: return t1 // t2 @compare_allclose(rtol=1e-6) def test_floordiv_scalar(t: Tensor) -> Tensor: return t // 3 @compare_allclose def test_rfloordiv_scalar(t: Tensor) -> Tensor: return 3 // (abs(t) + 1e-8) @compare_all def test_mod(t1: Tensor, t2: Tensor) -> Tensor: return t1 % (abs(t2) + 1) @compare_all def test_mod_scalar(t: Tensor) -> Tensor: return t % 3 @compare_all def test_getitem(t: Tensor) -> Tensor: return t[2] @compare_all def test_getitem_tuple(dummy: Tensor) -> Tensor: t = ep.arange(dummy, 8).float32().reshape((2, 4)) return t[1, 3] @compare_all def test_getitem_newaxis(dummy: Tensor) -> Tensor: t = ep.arange(dummy, 8).float32() return t[ep.newaxis] @compare_all def test_getitem_ellipsis_newaxis(dummy: Tensor) -> Tensor: t = ep.arange(dummy, 8).float32() return 
t[..., ep.newaxis] @compare_all def test_getitem_tensor(dummy: Tensor) -> Tensor: t = ep.arange(dummy, 32).float32() indices = ep.arange(t, 3, 10, 2) return t[indices] @compare_all def test_getitem_range(dummy: Tensor) -> Tensor: t = ep.arange(dummy, 32).float32() indices = range(3, 10, 2) return t[indices] @compare_all def test_getitem_list(dummy: Tensor) -> Tensor: t = ep.arange(dummy, 32).float32() indices = list(range(3, 10, 2)) return t[indices] @compare_all def test_getitem_ndarray(dummy: Tensor) -> Tensor: t = ep.arange(dummy, 32).float32() indices = np.arange(3, 10, 2) return t[indices] @compare_all def test_getitem_tuple_tensors(dummy: Tensor) -> Tensor: t = ep.arange(dummy, 32).float32().reshape((8, 4)) rows = ep.arange(t, len(t)) indices = ep.arange(t, len(t)) % t.shape[1] return t[rows, indices] @compare_all def test_getitem_tuple_tensors_full(dummy: Tensor) -> Tensor: t = ep.arange(dummy, 32).float32().reshape((8, 4)) rows = ep.arange(t, len(t))[:, np.newaxis].tile((1, t.shape[-1])) cols = t.argsort(axis=-1) return t[rows, cols] @compare_all def test_getitem_tuple_tensors_full_broadcast(dummy: Tensor) -> Tensor: t = ep.arange(dummy, 32).float32().reshape((8, 4)) rows = ep.arange(t, len(t))[:, np.newaxis] cols = t.argsort(axis=-1) return t[rows, cols] @compare_all def test_getitem_tuple_range_tensor(dummy: Tensor) -> Tensor: t = ep.arange(dummy, 32).float32().reshape((8, 4)) rows = range(len(t)) indices = ep.arange(t, len(t)) % t.shape[1] return t[rows, indices] @compare_all def test_getitem_tuple_range_range(dummy: Tensor) -> Tensor: t = ep.arange(dummy, 36).float32().reshape((6, 6)) rows = cols = range(len(t)) return t[rows, cols] @compare_all def test_getitem_tuple_list_tensor(dummy: Tensor) -> Tensor: t = ep.arange(dummy, 32).float32().reshape((8, 4)) rows = list(range(len(t))) indices = ep.arange(t, len(t)) % t.shape[1] return t[rows, indices] @compare_all def test_getitem_slice(t: Tensor) -> Tensor: return t[1:3] @compare_all def test_getitem_slice_slice(dummy: Tensor) -> Tensor: t = ep.arange(dummy, 32).float32().reshape((4, 8)) return t[:, :3] @compare_all def test_getitem_boolean_tensor(dummy: Tensor) -> Tensor: t = ep.arange(dummy, 32).float32().reshape((4, 8)) indices = ep.arange(t, 4) <= 2 return t[indices] @compare_all def test_take_along_axis_2d(dummy: Tensor) -> Tensor: t = ep.arange(dummy, 32).float32().reshape((8, 4)) indices = ep.arange(t, len(t)) % t.shape[-1] return ep.take_along_axis(t, indices[..., ep.newaxis], axis=-1) @compare_all def test_take_along_axis_3d(dummy: Tensor) -> Tensor: t = ep.arange(dummy, 64).float32().reshape((2, 8, 4)) indices = ep.arange(t, 2 * 8).reshape((2, 8, 1)) % t.shape[-1] return ep.take_along_axis(t, indices, axis=-1) @compare_all def test_sqrt(t: Tensor) -> Tensor: return ep.sqrt(t) @compare_equal def test_shape(t: Tensor) -> Shape: return t.shape @compare_all def test_reshape(t: Tensor) -> Tensor: shape = (1,) + t.shape + (1,) return ep.reshape(t, shape) @compare_all def test_reshape_minus_1(t: Tensor) -> Tensor: return ep.reshape(t, -1) @compare_all def test_reshape_int(t: Tensor) -> Tensor: n = 1 for k in t.shape: n *= k return ep.reshape(t, n) @compare_all def test_clip(t: Tensor) -> Tensor: return ep.clip(t, 2, 3.5) @compare_all def test_sign(t: Tensor) -> Tensor: return ep.sign(t) @compare_all def test_sum(t: Tensor) -> Tensor: return ep.sum(t) @compare_all def test_sum_axis(t: Tensor) -> Tensor: return ep.sum(t, axis=0) @compare_all def test_sum_axes(dummy: Tensor) -> Tensor: t = ep.ones(dummy, 30).float32().reshape((3, 
5, 2)) return ep.sum(t, axis=(0, 1)) @compare_all def test_sum_keepdims(t: Tensor) -> Tensor: return ep.sum(t, axis=0, keepdims=True) @compare_all def test_sum_none_keepdims(t: Tensor) -> Tensor: return ep.sum(t, axis=None, keepdims=True) @compare_all def test_sum_bool(t: Tensor) -> Tensor: return ep.sum(t != 0) @compare_all def test_sum_int(t: Tensor) -> Tensor: return ep.sum(ep.arange(t, 5)) @compare_all def test_prod(t: Tensor) -> Tensor: return ep.prod(t) @compare_all def test_prod_axis(t: Tensor) -> Tensor: return ep.prod(t, axis=0) @compare_all def test_prod_axes(dummy: Tensor) -> Tensor: t = ep.ones(dummy, 30).float32().reshape((3, 5, 2)) return ep.prod(t, axis=(0, 1)) @compare_all def test_prod_keepdims(t: Tensor) -> Tensor: return ep.prod(t, axis=0, keepdims=True) @compare_all def test_prod_none_keepdims(t: Tensor) -> Tensor: return ep.prod(t, axis=None, keepdims=True) @compare_all def test_prod_bool(t: Tensor) -> Tensor: return ep.prod(t != 0) @compare_all def test_prod_int(t: Tensor) -> Tensor: return ep.prod(ep.arange(t, 5)) @compare_all def test_mean(t: Tensor) -> Tensor: return ep.mean(t) @compare_all def test_mean_axis(t: Tensor) -> Tensor: return ep.mean(t, axis=0) @compare_all def test_mean_axes(dummy: Tensor) -> Tensor: t = ep.ones(dummy, 30).float32().reshape((3, 5, 2)) return ep.mean(t, axis=(0, 1)) @compare_all def test_mean_keepdims(t: Tensor) -> Tensor: return ep.mean(t, axis=0, keepdims=True) @compare_all def test_mean_none_keepdims(t: Tensor) -> Tensor: return ep.mean(t, axis=None, keepdims=True) @compare_all def test_all(t: Tensor) -> Tensor: return ep.all(t > 3) @compare_all def test_all_axis(t: Tensor) -> Tensor: return ep.all(t > 3, axis=0) @compare_all def test_all_axes(dummy: Tensor) -> Tensor: t = ep.arange(dummy, 30).float32().reshape((3, 5, 2)) return ep.all(t > 3, axis=(0, 1)) @compare_all def test_all_keepdims(t: Tensor) -> Tensor: return ep.all(t > 3, axis=0, keepdims=True) @compare_all def test_all_none_keepdims(t: Tensor) -> Tensor: return ep.all(t > 3, axis=None, keepdims=True) @compare_all def test_any(t: Tensor) -> Tensor: return ep.any(t > 3) @compare_all def test_any_axis(t: Tensor) -> Tensor: return ep.any(t > 3, axis=0) @compare_all def test_any_axes(dummy: Tensor) -> Tensor: t = ep.arange(dummy, 30).float32().reshape((3, 5, 2)) return ep.any(t > 3, axis=(0, 1)) @compare_all def test_any_keepdims(t: Tensor) -> Tensor: return ep.any(t > 3, axis=0, keepdims=True) @compare_all def test_any_none_keepdims(t: Tensor) -> Tensor: return ep.any(t > 3, axis=None, keepdims=True) @compare_all def test_min(t: Tensor) -> Tensor: return ep.min(t) @compare_all def test_min_axis(t: Tensor) -> Tensor: return ep.min(t, axis=0) @compare_all def test_min_axes(dummy: Tensor) -> Tensor: t = ep.ones(dummy, 30).float32().reshape((3, 5, 2)) return ep.min(t, axis=(0, 1)) @compare_all def test_min_keepdims(t: Tensor) -> Tensor: return ep.min(t, axis=0, keepdims=True) @compare_all def test_min_none_keepdims(t: Tensor) -> Tensor: return ep.min(t, axis=None, keepdims=True) @compare_all def test_max(t: Tensor) -> Tensor: return ep.max(t) @compare_all def test_max_axis(t: Tensor) -> Tensor: return ep.max(t, axis=0) @compare_all def test_max_axes(dummy: Tensor) -> Tensor: t = ep.ones(dummy, 30).float32().reshape((3, 5, 2)) return ep.max(t, axis=(0, 1)) @compare_all def test_max_keepdims(t: Tensor) -> Tensor: return ep.max(t, axis=0, keepdims=True) @compare_all def test_max_none_keepdims(t: Tensor) -> Tensor: return ep.max(t, axis=None, keepdims=True) 
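# Editor's note: the example below is an illustrative sketch, not part of the original
# test suite. It shows the pattern used throughout this file: a function that returns a
# Tensor is decorated with @compare_all (exact comparison) or @compare_allclose
# (floating-point comparison), and the decorator re-runs it against the NumPy backend to
# check that all backends agree. The test name is hypothetical.
@compare_all
def test_abs_neg_example(t: Tensor) -> Tensor:
    return ep.abs(-t)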
@compare_allclose(rtol=1e-6) def test_exp(t: Tensor) -> Tensor: return ep.exp(t) @compare_allclose def test_log(t: Tensor) -> Tensor: return ep.log(t.maximum(1e-8)) @compare_allclose def test_log2(t: Tensor) -> Tensor: return ep.log2(t.maximum(1e-8)) @compare_allclose def test_log10(t: Tensor) -> Tensor: return ep.log10(t.maximum(1e-8)) @compare_allclose def test_log1p(t: Tensor) -> Tensor: return ep.log1p(t) @compare_allclose(rtol=1e-6) def test_tanh(t: Tensor) -> Tensor: return ep.tanh(t) @compare_allclose(rtol=1e-6) def test_arctanh(t: Tensor) -> Tensor: return ep.arctanh((t - t.mean()) / t.max()) @compare_all def test_abs_op(t: Tensor) -> Tensor: return abs(t) @compare_all def test_abs(t: Tensor) -> Tensor: return ep.abs(t) @compare_all def test_minimum(t1: Tensor, t2: Tensor) -> Tensor: return ep.minimum(t1, t2) @compare_all def test_minimum_scalar(t: Tensor) -> Tensor: return ep.minimum(t, 3) @compare_all def test_rminimum_scalar(t: Tensor) -> Tensor: return ep.minimum(3, t) @compare_all def test_maximum(t1: Tensor, t2: Tensor) -> Tensor: return ep.maximum(t1, t2) @compare_all def test_maximum_scalar(t: Tensor) -> Tensor: return ep.maximum(t, 3) @compare_all def test_rmaximum_scalar(t: Tensor) -> Tensor: return ep.maximum(3, t) @compare_all def test_argmin(t: Tensor) -> Tensor: return ep.argmin(t) @compare_all def test_argmin_axis(t: Tensor) -> Tensor: return ep.argmin(t, axis=0) @compare_all def test_argmax(t: Tensor) -> Tensor: return ep.argmax(t) @compare_all def test_argmax_axis(t: Tensor) -> Tensor: return ep.argmax(t, axis=0) @compare_all def test_logical_and(t: Tensor) -> Tensor: return ep.logical_and(t < 3, t > 1) @compare_all def test_logical_and_scalar(t: Tensor) -> Tensor: return ep.logical_and(True, t < 3) @compare_all def test_logical_or(t: Tensor) -> Tensor: return ep.logical_or(t > 3, t < 1) @compare_all def test_logical_or_scalar(t: Tensor) -> Tensor: return ep.logical_or(True, t < 1) @compare_all def test_logical_not(t: Tensor) -> Tensor: return ep.logical_not(t > 3) @compare_all def test_isnan_false(t: Tensor) -> Tensor: return ep.isnan(t) @compare_all def test_isnan_true(t: Tensor) -> Tensor: return ep.isnan(t + ep.nan) @compare_all def test_isinf(t: Tensor) -> Tensor: return ep.isinf(t) @compare_all def test_isinf_posinf(t: Tensor) -> Tensor: return ep.isinf(t + ep.inf) @compare_all def test_isinf_neginf(t: Tensor) -> Tensor: return ep.isinf(t - ep.inf) @compare_all def test_zeros_like(t: Tensor) -> Tensor: return ep.zeros_like(t) @compare_all def test_ones_like(t: Tensor) -> Tensor: return ep.ones_like(t) @compare_all def test_full_like(t: Tensor) -> Tensor: return ep.full_like(t, 5) @pytest.mark.parametrize("value", [1, -1, 2]) @compare_all def test_onehot_like(dummy: Tensor, value: float) -> Tensor: t = ep.arange(dummy, 18).float32().reshape((6, 3)) indices = ep.arange(t, 6) // 2 return ep.onehot_like(t, indices, value=value) @compare_all def test_zeros_scalar(t: Tensor) -> Tensor: return ep.zeros(t, 5) @compare_all def test_zeros_tuple(t: Tensor) -> Tensor: return ep.zeros(t, (2, 3)) @compare_all def test_ones_scalar(t: Tensor) -> Tensor: return ep.ones(t, 5) @compare_all def test_ones_tuple(t: Tensor) -> Tensor: return ep.ones(t, (2, 3)) @compare_all def test_full_scalar(t: Tensor) -> Tensor: return ep.full(t, 5, 4.0) @compare_all def test_full_tuple(t: Tensor) -> Tensor: return ep.full(t, (2, 3), 4.0) @compare_equal def test_uniform_scalar(t: Tensor) -> Shape: return ep.uniform(t, 5).shape @compare_equal def test_uniform_tuple(t: Tensor) -> Shape: return 
@compare_equal
def test_normal_scalar(t: Tensor) -> Shape:
    return ep.normal(t, 5).shape


@compare_equal
def test_normal_tuple(t: Tensor) -> Shape:
    return ep.normal(t, (2, 3)).shape


@compare_all
def test_argsort(dummy: Tensor) -> Tensor:
    t = ep.arange(dummy, 6).float32().reshape((2, 3))
    return ep.argsort(t)


@compare_all
def test_sort(dummy: Tensor) -> Tensor:
    t = -ep.arange(dummy, 6).float32().reshape((2, 3))
    return ep.sort(t)


@compare_all
def test_topk_values(dummy: Tensor) -> Tensor:
    t = (ep.arange(dummy, 27).reshape((3, 3, 3)) ** 2 * 10000 % 1234).float32()
    values, _ = ep.topk(t, 2)
    return values


@compare_all
def test_topk_indices(dummy: Tensor) -> Tensor:
    t = -(ep.arange(dummy, 27).reshape((3, 3, 3)) ** 2 * 10000 % 1234).float32()
    _, indices = ep.topk(t, 2)
    return indices


@compare_all
def test_transpose(dummy: Tensor) -> Tensor:
    t = ep.arange(dummy, 8).float32().reshape((2, 4))
    return ep.transpose(t)


@compare_all
def test_transpose_axes(dummy: Tensor) -> Tensor:
    t = ep.arange(dummy, 60).float32().reshape((3, 4, 5))
    return ep.transpose(t, axes=(1, 2, 0))


@compare_all
def test_where(t: Tensor) -> Tensor:
    return ep.where(t >= 3, t, -t)


@compare_all
def test_where_first_scalar(t: Tensor) -> Tensor:
    return ep.where(t >= 3, 2, -t)


@compare_all
def test_where_second_scalar(t: Tensor) -> Tensor:
    return ep.where(t >= 3, t, 2)


@compare_all
def test_where_first_scalar64(dummy: Tensor) -> Tensor:
    t = ep.arange(dummy, 60).float64().reshape((3, 4, 5))
    return ep.where(t >= 3, 2, -t)


@compare_all
def test_where_second_scalar64(dummy: Tensor) -> Tensor:
    t = ep.arange(dummy, 60).float64().reshape((3, 4, 5))
    return ep.where(t >= 3, t, 2)


@compare_all
def test_where_both_scalars(t: Tensor) -> Tensor:
    return ep.where(t >= 3, 2, 5)


@compare_all
def test_tile(t: Tensor) -> Tensor:
    return ep.tile(t, (3,) * t.ndim)


@compare_all
def test_matmul(dummy: Tensor) -> Tensor:
    t = ep.arange(dummy, 8).float32().reshape((2, 4))
    return ep.matmul(t, t.T)


@compare_all
def test_matmul_operator(dummy: Tensor) -> Tensor:
    t = ep.arange(dummy, 8).float32().reshape((2, 4))
    return t @ t.T


@compare_allclose(rtol=1e-6)
def test_softmax(t: Tensor) -> Tensor:
    return ep.softmax(t)


@compare_allclose(rtol=1e-5)
def test_log_softmax(t: Tensor) -> Tensor:
    return ep.log_softmax(t)


@compare_allclose
def test_crossentropy(dummy: Tensor) -> Tensor:
    t = ep.arange(dummy, 50).reshape((10, 5)).float32()
    t = t / t.max()
    return ep.crossentropy(t, t.argmax(axis=-1))


@pytest.mark.parametrize(
    "array, output",
    itertools.product(
        [
            np.array([[1, 2], [3, 4]]),
            np.array([[[1, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]]]),
            np.arange(100).reshape((10, 10)),
        ],
        ["sign", "logdet"],
    ),
    ids=map(
        lambda *l: "_".join(*l),
        itertools.product(
            ["matrix_finite", "stack_of_matrices", "matrix_infinite"],
            ["sign", "logdet"],
        ),
    ),
)
@compare_allclose
def test_slogdet(dummy: Tensor, array: Tensor, output: str) -> Tensor:
    a = ep.from_numpy(dummy, array).float32()
    outputs = dict()
    outputs["sign"], outputs["logdet"] = ep.slogdet(a)
    return outputs[output]


@pytest.mark.parametrize("axis", [0, 1, -1])
@compare_all
def test_stack(t1: Tensor, t2: Tensor, axis: int) -> Tensor:
    return ep.stack([t1, t2], axis=axis)


@compare_all
def test_concatenate_axis0(dummy: Tensor) -> Tensor:
    t1 = ep.arange(dummy, 12).float32().reshape((4, 3))
    t2 = ep.arange(dummy, 20, 32, 2).float32().reshape((2, 3))
    return ep.concatenate([t1, t2], axis=0)


@compare_all
def test_concatenate_axis1(dummy: Tensor) -> Tensor:
    t1 = ep.arange(dummy, 12).float32().reshape((3, 4))
    t2 = ep.arange(dummy, 20, 32, 2).float32().reshape((3, 2))
    return ep.concatenate([t1, t2], axis=1)
@pytest.mark.parametrize("axis", [0, 1, -1])
@compare_all
def test_expand_dims(t: Tensor, axis: int) -> Tensor:
    return ep.expand_dims(t, axis)


@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1)])
@compare_all
def test_squeeze(t: Tensor, axis: Optional[AxisAxes]) -> Tensor:
    t = t.expand_dims(axis=0).expand_dims(axis=1)
    return ep.squeeze(t, axis=axis)


@compare_all
def test_arange(dummy: Tensor) -> Tensor:
    return ep.arange(dummy, 6)


@compare_all
def test_arange_start(dummy: Tensor) -> Tensor:
    return ep.arange(dummy, 5, 10)


@compare_all
def test_arange_step(dummy: Tensor) -> Tensor:
    return ep.arange(dummy, 4, 8, 2)


@compare_all
def test_cumsum(t: Tensor) -> Tensor:
    return ep.cumsum(t)


@compare_all
def test_cumsum_axis(t: Tensor) -> Tensor:
    return ep.cumsum(t, axis=0)


@compare_all
def test_flip(t: Tensor) -> Tensor:
    return ep.flip(t)


@compare_all
def test_flip_axis(t: Tensor) -> Tensor:
    return ep.flip(t, axis=0)


@pytest.mark.parametrize("indexing", ["ij", "xy"])
@pytest.mark.parametrize("i", [0, 1])
@compare_all
def test_meshgrid_a(dummy: Tensor, indexing: str, i: int) -> Tensor:
    t1 = ep.arange(dummy, 5)
    t2 = ep.arange(dummy, 3)
    results = ep.meshgrid(t1, t2, indexing=indexing)
    assert len(results) == 2
    return results[i]


@pytest.mark.parametrize(
    "mode,value", [("constant", 0), ("constant", -2), ("reflect", 0)]
)
@compare_all
def test_pad(dummy: Tensor, mode: str, value: float) -> Tensor:
    t = ep.arange(dummy, 120).reshape((2, 3, 4, 5)).float32()
    return ep.pad(t, ((0, 0), (0, 0), (2, 3), (1, 2)), mode=mode, value=value)


@compare_all
def test_index_update_row(dummy: Tensor) -> Tensor:
    x = ep.ones(dummy, (3, 4))
    return ep.index_update(x, ep.index[1], ep.ones(x, 4) * 66.0)


@compare_all
def test_index_update_row_scalar(dummy: Tensor) -> Tensor:
    x = ep.ones(dummy, (3, 4))
    return ep.index_update(x, ep.index[1], 66.0)


@compare_all
def test_index_update_column(dummy: Tensor) -> Tensor:
    x = ep.ones(dummy, (3, 4))
    return ep.index_update(x, ep.index[:, 1], ep.ones(x, 3) * 66.0)


@compare_all
def test_index_update_column_scalar(dummy: Tensor) -> Tensor:
    x = ep.ones(dummy, (3, 4))
    return ep.index_update(x, ep.index[:, 1], 66.0)


@compare_all
def test_index_update_indices(dummy: Tensor) -> Tensor:
    x = ep.ones(dummy, (3, 4))
    ind = ep.from_numpy(dummy, np.array([0, 1, 2, 1]))
    return ep.index_update(x, ep.index[ind, ep.arange(x, 4)], ep.ones(x, 4) * 33.0)


@compare_all
def test_index_update_indices_scalar(dummy: Tensor) -> Tensor:
    x = ep.ones(dummy, (3, 4))
    ind = ep.from_numpy(dummy, np.array([0, 1, 2, 1]))
    return ep.index_update(x, ep.index[ind, ep.arange(x, 4)], 33.0)


@compare_all
def test_lt(t1: Tensor, t2: Tensor) -> Tensor:
    return t1 < t2


@compare_all
def test_lt_scalar(t1: Tensor, t2: Tensor) -> Tensor:
    return 3 < t2


@compare_all
def test_le(t1: Tensor, t2: Tensor) -> Tensor:
    return t1 <= t2


@compare_all
def test_le_scalar(t1: Tensor, t2: Tensor) -> Tensor:
    return 3 <= t2


@compare_all
def test_gt(t1: Tensor, t2: Tensor) -> Tensor:
    return t1 > t2


@compare_all
def test_gt_scalar(t1: Tensor, t2: Tensor) -> Tensor:
    return 3 > t2


@compare_all
def test_ge(t1: Tensor, t2: Tensor) -> Tensor:
    return t1 >= t2


@compare_all
def test_ge_scalar(t1: Tensor, t2: Tensor) -> Tensor:
    return 3 >= t2


@compare_all
def test_eq(t1: Tensor, t2: Tensor) -> Tensor:
    return t1 == t2


@compare_all
def test_eq_scalar(t1: Tensor, t2: Tensor) -> Tensor:
    return cast(Tensor, 3 == t2)


@compare_all
def test_ne(t1: Tensor, t2: Tensor) -> Tensor:
    return t1 != t2
@compare_all
def test_ne_scalar(t1: Tensor, t2: Tensor) -> Tensor:
    return cast(Tensor, 3 != t2)


@compare_all
def test_float_int_lt(t1: Tensor, t2int: Tensor) -> Tensor:
    return t1 < t2int


@compare_all
def test_float_int_le(t1: Tensor, t2int: Tensor) -> Tensor:
    return t1 <= t2int


@compare_all
def test_float_int_gt(t1: Tensor, t2int: Tensor) -> Tensor:
    return t1 > t2int


@compare_all
def test_float_int_ge(t1: Tensor, t2int: Tensor) -> Tensor:
    return t1 >= t2int


@compare_all
def test_float_int_eq(t1: Tensor, t2int: Tensor) -> Tensor:
    return t1 == t2int


@compare_all
def test_float_int_ne(t1: Tensor, t2int: Tensor) -> Tensor:
    return t1 != t2int


@compare_all
def test_int_float_lt(t1int: Tensor, t2: Tensor) -> Tensor:
    return t1int < t2


@compare_all
def test_int_float_le(t1int: Tensor, t2: Tensor) -> Tensor:
    return t1int <= t2


@compare_all
def test_int_float_gt(t1int: Tensor, t2: Tensor) -> Tensor:
    return t1int > t2


@compare_all
def test_int_float_ge(t1int: Tensor, t2: Tensor) -> Tensor:
    return t1int >= t2


@compare_all
def test_int_float_eq(t1int: Tensor, t2: Tensor) -> Tensor:
    return t1int == t2


@compare_all
def test_int_float_ne(t1int: Tensor, t2: Tensor) -> Tensor:
    return t1int != t2


@compare_all
def test_norms_l0(t: Tensor) -> Tensor:
    return t.norms.l0()


@compare_all
def test_norms_l1(t: Tensor) -> Tensor:
    return t.norms.l1()


@compare_all
def test_norms_l2(t: Tensor) -> Tensor:
    return t.norms.l2()


@compare_all
def test_norms_linf(t: Tensor) -> Tensor:
    return t.norms.linf()


@compare_all
def test_norms_lp(t: Tensor) -> Tensor:
    return t.norms.lp(2)


@compare_all
def test_norms_cache(t: Tensor) -> Tensor:
    return t.norms.l1() + t.norms.l2()
eagerpy-0.30.0/tests/test_norms.py000066400000000000000000000052731410374365400171550ustar00rootroot00000000000000
from typing import Optional

import pytest
from numpy.testing import assert_allclose
from numpy.linalg import norm
import numpy as np

import eagerpy as ep
from eagerpy import Tensor
from eagerpy.norms import l0, l1, l2, linf, lp

norms = {0: l0, 1: l1, 2: l2, ep.inf: linf}


@pytest.fixture
def x1d(dummy: Tensor) -> Tensor:
    return ep.arange(dummy, 10).float32() / 7.0


@pytest.fixture
def x2d(dummy: Tensor) -> Tensor:
    return ep.arange(dummy, 12).float32().reshape((3, 4)) / 7.0


@pytest.fixture
def x4d(dummy: Tensor) -> Tensor:
    return ep.arange(dummy, 2 * 3 * 4 * 5).float32().reshape((2, 3, 4, 5)) / 7.0


@pytest.mark.parametrize("p", [0, 1, 2, ep.inf])
def test_1d(x1d: Tensor, p: float) -> None:
    assert_allclose(lp(x1d, p).numpy(), norm(x1d.numpy(), ord=p))
    assert_allclose(norms[p](x1d).numpy(), norm(x1d.numpy(), ord=p))


@pytest.mark.parametrize("p", [0, 1, 2, 3, 4, ep.inf])
@pytest.mark.parametrize("axis", [0, 1, -1])
@pytest.mark.parametrize("keepdims", [False, True])
def test_2d(x2d: Tensor, p: float, axis: int, keepdims: bool) -> None:
    assert isinstance(axis, int)  # see test_4d for the more general test
    assert_allclose(
        lp(x2d, p, axis=axis, keepdims=keepdims).numpy(),
        norm(x2d.numpy(), ord=p, axis=axis, keepdims=keepdims),
        rtol=1e-6,
    )
    if p not in norms:
        return
    assert_allclose(
        norms[p](x2d, axis=axis, keepdims=keepdims).numpy(),
        norm(x2d.numpy(), ord=p, axis=axis, keepdims=keepdims),
        rtol=1e-6,
    )


@pytest.mark.parametrize("p", [0, 1, 2, 3, 4, ep.inf])
@pytest.mark.parametrize(
    "axis",
    [
        None,
        0,
        1,
        2,
        3,
        -1,
        -2,
        -3,
        -4,
        (0, 1),
        (1, 2),
        (1, 3),
        (1, 2, 3),
        (0, 1, 3),
        (2, 1, 0),
    ],
)
@pytest.mark.parametrize("keepdims", [False, True])
def test_4d(
    x4d: Tensor, p: float, axis: Optional[ep.types.AxisAxes], keepdims: bool
) -> None:
    actual = lp(x4d, p, axis=axis, keepdims=keepdims).numpy()
    # numpy does not support arbitrary axes (limited to vector and matrix norms),
    # so build the reference value by moving the reduced axes to the end,
    # flattening them into a single trailing axis, and taking numpy's vector
    # norm over that axis
    if axis is None:
        axes = tuple(range(x4d.ndim))
    elif not isinstance(axis, tuple):
        axes = (axis,)
    else:
        axes = axis
    del axis
    axes = tuple(i % x4d.ndim for i in axes)  # normalize negative axes
    x = x4d.numpy()
    other = tuple(i for i in range(x.ndim) if i not in axes)
    x = np.transpose(x, other + axes)
    x = x.reshape(x.shape[: len(other)] + (-1,))
    desired = norm(x, ord=p, axis=-1)
    if keepdims:
        shape = tuple(1 if i in axes else x4d.shape[i] for i in range(x4d.ndim))
        desired = desired.reshape(shape)
    assert_allclose(actual, desired, rtol=1e-6)